From: Andrew Reynolds Date: Thu, 12 Dec 2019 20:38:42 +0000 (-0600) Subject: Use the node-level datatypes API (#3556) X-Git-Tag: cvc5-1.0.0~3771 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=e6d20229cf21a3614ac546514f42bd06002d52b8;p=cvc5.git Use the node-level datatypes API (#3556) --- diff --git a/src/expr/datatype.cpp b/src/expr/datatype.cpp index 14b21a96a..7e5fb6d7d 100644 --- a/src/expr/datatype.cpp +++ b/src/expr/datatype.cpp @@ -835,6 +835,7 @@ std::ostream& operator<<(std::ostream& out, const DatatypeIndexConstant& dic) { return out << "index_" << dic.getIndex(); } + std::string Datatype::getName() const { ExprManagerScope ems(*d_em); diff --git a/src/expr/expr_manager_template.cpp b/src/expr/expr_manager_template.cpp index 411d64a1b..1981d0a7d 100644 --- a/src/expr/expr_manager_template.cpp +++ b/src/expr/expr_manager_template.cpp @@ -681,7 +681,7 @@ std::vector ExprManager::mkMutualDatatypeTypes( for(std::vector::iterator i = datatypes.begin(), i_end = datatypes.end(); i != i_end; ++i) { dt_copies.push_back( new Datatype( *i ) ); } - + // First do some sanity checks, set up the final Type to be used for // each datatype, and set up the "named resolutions" used to handle // simple self- and mutual-recursion, for example in the definition diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 080306d39..42f4d3e06 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -16,6 +16,7 @@ #include "theory/datatypes/datatypes_rewriter.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" #include "options/datatypes_options.h" @@ -59,8 +60,8 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) } TNode constructor = in[0].getOperator(); size_t constructorIndex = utils::indexOf(constructor); - const Datatype& dt = Datatype::datatypeOf(constructor.toExpr()); - const DatatypeConstructor& c = dt[constructorIndex]; + const DType& dt = utils::datatypeOf(constructor); + const DTypeConstructor& c = dt[constructorIndex]; unsigned weight = c.getWeight(); children.push_back(nm->mkConst(Rational(weight))); Node res = @@ -140,7 +141,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) std::vector cases; std::vector rets; TypeNode t = h.getType(); - const Datatype& dt = t.getDatatype(); + const DType& dt = t.getDType(); for (size_t k = 1, nchild = in.getNumChildren(); k < nchild; k++) { Node c = in[k]; @@ -166,7 +167,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) // cons is null in the default case if (!cons.isNull()) { - cindex = Datatype::indexOf(cons.toExpr()); + cindex = utils::indexOf(cons); } Node body; if (ck == MATCH_CASE) @@ -189,9 +190,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in) { vars.push_back(c[0][i]); Node sc = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(t.toType(), i)), - h); + APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(t, i), h); subs.push_back(sc); } } @@ -264,13 +263,11 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in) if (in.getKind() == kind::APPLY_CONSTRUCTOR) { TypeNode tn = in.getType(); - Type t = tn.toType(); - DatatypeType dt = DatatypeType(t); // check for parametric datatype constructors // to ensure a normal form, all parameteric datatype constructors must have // a type ascription - if (dt.isParametric()) + if (tn.isParametricDatatype()) { if (in.getOperator().getKind() != kind::APPLY_TYPE_ASCRIPTION) { @@ -279,11 +276,10 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in) << std::endl; Node op = in.getOperator(); // get the constructor object - const DatatypeConstructor& dtc = - Datatype::datatypeOf(op.toExpr())[utils::indexOf(op)]; + const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)]; // create ascribed constructor type Node tc = NodeManager::currentNM()->mkConst( - AscriptionType(dtc.getSpecializedConstructorType(t))); + AscriptionType(dtc.getSpecializedConstructorType(tn).toType())); Node op_new = NodeManager::currentNM()->mkNode( kind::APPLY_TYPE_ASCRIPTION, tc, op); // make new node @@ -331,11 +327,11 @@ RewriteResponse DatatypesRewriter::rewriteSelector(TNode in) // e.g. "pred(zero)". TypeNode tn = in.getType(); TypeNode argType = in[0].getType(); - Expr selector = in.getOperator().toExpr(); + Node selector = in.getOperator(); TNode constructor = in[0].getOperator(); size_t constructorIndex = utils::indexOf(constructor); - const Datatype& dt = Datatype::datatypeOf(selector); - const DatatypeConstructor& c = dt[constructorIndex]; + const DType& dt = utils::datatypeOf(selector); + const DTypeConstructor& c = dt[constructorIndex]; Trace("datatypes-rewrite-debug") << "Rewriting collapsable selector : " << in; Trace("datatypes-rewrite-debug") << ", cindex = " << constructorIndex @@ -355,7 +351,7 @@ RewriteResponse DatatypesRewriter::rewriteSelector(TNode in) // The argument index of external selectors (applications of // APPLY_SELECTOR) is given by an attribute and obtained via indexOf below // The argument is only valid if it is the proper constructor. - selectorIndex = Datatype::indexOf(selector); + selectorIndex = utils::indexOf(selector); if (selectorIndex < 0 || selectorIndex >= static_cast(c.getNumArgs())) { @@ -400,7 +396,7 @@ RewriteResponse DatatypesRewriter::rewriteSelector(TNode in) //} if (tn.isDatatype()) { - const Datatype& dta = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dta = tn.getDType(); useTe = !dta.isCodatatype(); } if (useTe) @@ -445,7 +441,7 @@ RewriteResponse DatatypesRewriter::rewriteTester(TNode in) return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(result)); } - const Datatype& dt = static_cast(in[0].getType().toType()).getDatatype(); + const DType& dt = in[0].getType().getDType(); if (dt.getNumConstructors() == 1 && !dt.isSygus()) { // only one constructor, so it must be diff --git a/src/theory/datatypes/sygus_extension.cpp b/src/theory/datatypes/sygus_extension.cpp index ceb5d2dab..42fb5cd07 100644 --- a/src/theory/datatypes/sygus_extension.cpp +++ b/src/theory/datatypes/sygus_extension.cpp @@ -14,6 +14,7 @@ #include "theory/datatypes/sygus_extension.h" +#include "expr/dtype.h" #include "expr/node_manager.h" #include "expr/sygus_datatype.h" #include "options/base_options.h" @@ -77,8 +78,9 @@ void SygusExtension::assertTester( int tindex, TNode n, Node exp, std::vector< N Assert(itt != d_testers.end()); int ptindex = (*itt).second; TypeNode ptn = n[0].getType(); - const Datatype& pdt = ((DatatypeType)ptn.toType()).getDatatype(); - int sindex_in_parent = pdt[ptindex].getSelectorIndexInternal( n.getOperator().toExpr() ); + const DType& pdt = ptn.getDType(); + int sindex_in_parent = + pdt[ptindex].getSelectorIndexInternal(n.getOperator()); // the tester is irrelevant in this branch if( sindex_in_parent==-1 ){ do_add = false; @@ -138,7 +140,7 @@ Node SygusExtension::getTermOrderPredicate( Node n1, Node n2 ) { sz_eq_cases.push_back( sz_eq ); if( options::sygusOpt1() ){ TypeNode tnc = n1.getType(); - const Datatype& cdt = ((DatatypeType)(tnc).toType()).getDatatype(); + const DType& cdt = tnc.getDType(); for( unsigned j=0; j case_conj; for (unsigned k = 0; k < j; k++) @@ -194,7 +196,10 @@ void SygusExtension::registerTerm( Node n, std::vector< Node >& lemmas ) { } } if( success ){ - Trace("sygus-sb-debug") << "Register : " << n << ", depth : " << d << ", top level = " << is_top_level << ", type = " << ((DatatypeType)tn.toType()).getDatatype().getName() << std::endl; + Trace("sygus-sb-debug") + << "Register : " << n << ", depth : " << d + << ", top level = " << is_top_level + << ", type = " << tn.getDType().getName() << std::endl; d_term_to_depth[n] = d; d_is_top_level[n] = is_top_level; registerSearchTerm( tn, d, n, is_top_level, lemmas ); @@ -221,7 +226,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v // nothing to do for non-datatype types return; } - const Datatype& dt = static_cast(ntn.toType()).getDatatype(); + const DType& dt = ntn.getDType(); if (!dt.isSygus()) { // nothing to do for non-sygus-datatype type @@ -265,7 +270,7 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v IntMap::const_iterator ittv = d_testers.find( x ); Assert(ittv != d_testers.end()); int tindex = (*ittv).second; - const Datatype& dti = ((DatatypeType)x.getType().toType()).getDatatype(); + const DType& dti = x.getType().getDType(); if( dti[tindex].getNumArgs()>0 ){ NodeMap::const_iterator itt = d_testers_exp.find( x ); Assert(itt != d_testers_exp.end()); @@ -365,7 +370,8 @@ void SygusExtension::assertTesterInternal( int tindex, TNode n, Node exp, std::v if( options::sygusSymBreakLazy() ){ Trace("sygus-sb-debug") << "Do lazy symmetry breaking...\n"; for( unsigned j=0; jmkNode( APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex].getSelectorInternal( ntn.toType(), j ) ), n ); + Node sel = nm->mkNode( + APPLY_SELECTOR_TOTAL, dt[tindex].getSelectorInternal(ntn, j), n); Trace("sygus-sb-debug2") << " activate child sel : " << sel << std::endl; Assert(d_active_terms.find(sel) == d_active_terms.end()); IntMap::const_iterator itt = d_testers.find( sel ); @@ -384,14 +390,13 @@ Node SygusExtension::getRelevancyCondition( Node n ) { Node cond; if( n.getKind()==APPLY_SELECTOR_TOTAL && options::sygusSymBreakRlv() ){ TypeNode ntn = n[0].getType(); - Type nt = ntn.toType(); - const Datatype& dt = ((DatatypeType)nt).getDatatype(); - Expr selExpr = n.getOperator().toExpr(); + const DType& dt = ntn.getDType(); + Node sel = n.getOperator(); if( options::dtSharedSelectors() ){ std::vector< Node > disj; bool excl = false; for( unsigned i=0; i(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); Assert(tindex >= 0 && tindex < static_cast(dt.getNumConstructors())); @@ -561,7 +566,7 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, quantifiers::SygusTypeInfo& ti = d_tds->getTypeInfo(tn); // get the sygus operator - Node sop = Node::fromExpr(dt[tindex].getSygusOp()); + Node sop = dt[tindex].getSygusOp(); // get the kind of the constructor operator Kind nk = ti.getConsNumKind(tindex); // is this the any-constant constructor? @@ -576,15 +581,13 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, unsigned dt_index_nargs = isAnyConstant ? 0 : dt[tindex].getNumArgs(); // builtin type - TypeNode tnb = TypeNode::fromType(dt.getSygusType()); + TypeNode tnb = dt.getSygusType(); // get children std::vector children; for (unsigned j = 0; j < dt_index_nargs; j++) { Node sel = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[tindex].getSelectorInternal(tn.toType(), j)), - n); + APPLY_SELECTOR_TOTAL, dt[tindex].getSelectorInternal(tn, j), n); Assert(sel.getType().isDatatype()); children.push_back(sel); } @@ -615,7 +618,7 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, // is the tindex^{th} constructor of dt. Thus, is-x_i( z ) is either // true or false below. - Node svl = Node::fromExpr(dt.getSygusVarList()); + Node svl = dt.getSygusVarList(); // for each variable Assert(!e.isNull()); TypeNode etn = e.getType(); @@ -699,8 +702,7 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, { children_solved[j] = i; TypeNode ctn = children[j].getType(); - const Datatype& cdt = - static_cast(ctn.toType()).getDatatype(); + const DType& cdt = ctn.getDType(); Assert(i < static_cast(cdt.getNumConstructors())); sbp_conj.push_back(utils::mkTester(children[j], i, cdt)); } @@ -820,8 +822,7 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, TypeNode tnc = nc.getType(); quantifiers::SygusTypeInfo& cti = d_tds->getTypeInfo(tnc); int anyc_cons_num = cti.getAnyConstantConsNum(); - const Datatype& cdt = - static_cast(tnc.toType()).getDatatype(); + const DType& cdt = tnc.getDType(); std::vector exp_const; for (unsigned k = 0, ncons = cdt.getNumConstructors(); k < ncons; k++) { @@ -909,10 +910,9 @@ Node SygusExtension::getSimpleSymBreakPred(Node e, && children[0].getType() == tn && children[1].getType() == tn) { // chainable - Node child11 = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[tindex].getSelectorInternal(tn.toType(), 1)), - children[0]); + Node child11 = nm->mkNode(APPLY_SELECTOR_TOTAL, + dt[tindex].getSelectorInternal(tn, 1), + children[0]); Assert(child11.getType() == children[1].getType()); Node order_pred_trans = nm->mkNode(OR, @@ -974,7 +974,7 @@ Node SygusExtension::registerSearchValue(Node a, // selector chain n. return n; } - const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { // don't register non-sygus-datatype terms @@ -992,9 +992,7 @@ Node SygusExtension::registerSearchValue(Node a, for (unsigned i = 0, nchild = nv.getNumChildren(); i < nchild; i++) { Node sel = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(tn.toType(), i)), - n); + APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(tn, i), n); Node nvc = registerSearchValue(a, sel, nv[i], @@ -1283,7 +1281,7 @@ void SygusExtension::registerSizeTerm(Node e, std::vector& lemmas) d_register_st[e] = false; return; } - const Datatype& dt = etn.getDatatype(); + const DType& dt = etn.getDType(); if (!dt.isSygus()) { // not a sygus datatype term @@ -1358,7 +1356,7 @@ void SygusExtension::registerSizeTerm(Node e, std::vector& lemmas) { // if it is variable agnostic, enforce top-level constraint that says no // variables occur pre-traversal at top-level - Node varList = Node::fromExpr(dt.getSygusVarList()); + Node varList = dt.getSygusVarList(); std::vector constraints; quantifiers::SygusTypeInfo& eti = d_tds->getTypeInfo(etn); for (const Node& v : varList) @@ -1672,7 +1670,7 @@ bool SygusExtension::checkValue(Node n, << std::endl; } TypeNode tn = n.getType(); - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); // ensure that the expected size bound is met @@ -1703,9 +1701,7 @@ bool SygusExtension::checkValue(Node n, } for( unsigned i=0; imkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(tn.toType(), i)), - n); + APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(tn, i), n); if (!checkValue(sel, vn[i], ind + 1, lemmas)) { return false; @@ -1719,14 +1715,15 @@ Node SygusExtension::getCurrentTemplate( Node n, std::map< TypeNode, int >& var_ TypeNode tn = n.getType(); IntMap::const_iterator it = d_testers.find( n ); Assert(it != d_testers.end()); - const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); int tindex = (*it).second; Assert(tindex >= 0); Assert(tindex < (int)dt.getNumConstructors()); std::vector< Node > children; - children.push_back( Node::fromExpr( dt[tindex].getConstructor() ) ); + children.push_back(dt[tindex].getConstructor()); for( unsigned i=0; imkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex].getSelectorInternal( tn.toType(), i ) ), n ); + Node sel = NodeManager::currentNM()->mkNode( + APPLY_SELECTOR_TOTAL, dt[tindex].getSelectorInternal(tn, i), n); Node cc = getCurrentTemplate( sel, var_count ); children.push_back( cc ); } diff --git a/src/theory/datatypes/sygus_simple_sym.cpp b/src/theory/datatypes/sygus_simple_sym.cpp index f1e8949af..21fb71bf7 100644 --- a/src/theory/datatypes/sygus_simple_sym.cpp +++ b/src/theory/datatypes/sygus_simple_sym.cpp @@ -153,8 +153,8 @@ class ReqTrie bool SygusSimpleSymBreak::considerArgKind( TypeNode tn, TypeNode tnp, Kind k, Kind pk, int arg) { - const Datatype& pdt = ((DatatypeType)(tnp).toType()).getDatatype(); - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& pdt = tnp.getDType(); + const DType& dt = tn.getDType(); quantifiers::SygusTypeInfo& ti = d_tds->getTypeInfo(tn); quantifiers::SygusTypeInfo& pti = d_tds->getTypeInfo(tnp); Assert(ti.hasKind(k)); @@ -178,7 +178,7 @@ bool SygusSimpleSymBreak::considerArgKind( // the argument types of the child must be the parent's type for (unsigned i = 0, nargs = dt[c].getNumArgs(); i < nargs; i++) { - TypeNode tn = TypeNode::fromType(dt[c].getArgType(i)); + TypeNode tn = dt[c].getArgType(i); if (tn != tnp) { return true; @@ -202,7 +202,7 @@ bool SygusSimpleSymBreak::considerArgKind( // negation normal form if (pk == k) { - rt.d_req_type = TypeNode::fromType(dt[c].getArgType(0)); + rt.d_req_type = dt[c].getArgType(0); } else { @@ -233,27 +233,25 @@ bool SygusSimpleSymBreak::considerArgKind( rt.d_req_kind = ITE; reqkc[1] = NOT; reqkc[2] = NOT; - rt.d_children[0].d_req_type = TypeNode::fromType(dt[c].getArgType(0)); + rt.d_children[0].d_req_type = dt[c].getArgType(0); } else if (k == LEQ || k == GT) { // (not (~ x y)) -----> (~ (+ y 1) x) rt.d_req_kind = k; rt.d_children[0].d_req_kind = PLUS; - rt.d_children[0].d_children[0].d_req_type = - TypeNode::fromType(dt[c].getArgType(1)); + rt.d_children[0].d_children[0].d_req_type = dt[c].getArgType(1); rt.d_children[0].d_children[1].d_req_const = NodeManager::currentNM()->mkConst(Rational(1)); - rt.d_children[1].d_req_type = TypeNode::fromType(dt[c].getArgType(0)); + rt.d_children[1].d_req_type = dt[c].getArgType(0); } else if (k == LT || k == GEQ) { // (not (~ x y)) -----> (~ y (+ x 1)) rt.d_req_kind = k; - rt.d_children[0].d_req_type = TypeNode::fromType(dt[c].getArgType(1)); + rt.d_children[0].d_req_type = dt[c].getArgType(1); rt.d_children[1].d_req_kind = PLUS; - rt.d_children[1].d_children[0].d_req_type = - TypeNode::fromType(dt[c].getArgType(0)); + rt.d_children[1].d_children[0].d_req_type = dt[c].getArgType(0); rt.d_children[1].d_children[1].d_req_const = NodeManager::currentNM()->mkConst(Rational(1)); } @@ -318,8 +316,7 @@ bool SygusSimpleSymBreak::considerArgKind( if (rk != UNDEFINED_KIND) { rt.d_children[i].d_req_kind = rk; - rt.d_children[i].d_children[0].d_req_type = - TypeNode::fromType(dt[c].getArgType(i)); + rt.d_children[i].d_children[0].d_req_type = dt[c].getArgType(i); } } } @@ -336,12 +333,10 @@ bool SygusSimpleSymBreak::considerArgKind( // (~ x (- y z)) ----> (~ (+ x z) y) // (~ (- y z) x) ----> (~ y (+ x z)) rt.d_req_kind = pk; - rt.d_children[arg].d_req_type = TypeNode::fromType(dt[c].getArgType(0)); + rt.d_children[arg].d_req_type = dt[c].getArgType(0); rt.d_children[oarg].d_req_kind = k == MINUS ? PLUS : BITVECTOR_PLUS; - rt.d_children[oarg].d_children[0].d_req_type = - TypeNode::fromType(pdt[pc].getArgType(oarg)); - rt.d_children[oarg].d_children[1].d_req_type = - TypeNode::fromType(dt[c].getArgType(1)); + rt.d_children[oarg].d_children[0].d_req_type = pdt[pc].getArgType(oarg); + rt.d_children[oarg].d_children[1].d_req_type = dt[c].getArgType(1); } else if (pk == PLUS || pk == BITVECTOR_PLUS) { @@ -350,11 +345,9 @@ bool SygusSimpleSymBreak::considerArgKind( rt.d_req_kind = pk == PLUS ? MINUS : BITVECTOR_SUB; int oarg = arg == 0 ? 1 : 0; rt.d_children[0].d_req_kind = pk; - rt.d_children[0].d_children[0].d_req_type = - TypeNode::fromType(pdt[pc].getArgType(oarg)); - rt.d_children[0].d_children[1].d_req_type = - TypeNode::fromType(dt[c].getArgType(0)); - rt.d_children[1].d_req_type = TypeNode::fromType(dt[c].getArgType(1)); + rt.d_children[0].d_children[0].d_req_type = pdt[pc].getArgType(oarg); + rt.d_children[0].d_children[1].d_req_type = dt[c].getArgType(0); + rt.d_children[1].d_req_type = dt[c].getArgType(1); } } else if (k == ITE) @@ -363,7 +356,7 @@ bool SygusSimpleSymBreak::considerArgKind( { // (o X (ite y z w) X') -----> (ite y (o X z X') (o X w X')) rt.d_req_kind = ITE; - rt.d_children[0].d_req_type = TypeNode::fromType(dt[c].getArgType(0)); + rt.d_children[0].d_req_type = dt[c].getArgType(0); unsigned n_args = pdt[pc].getNumArgs(); for (unsigned r = 1; r <= 2; r++) { @@ -372,13 +365,11 @@ bool SygusSimpleSymBreak::considerArgKind( { if ((int)q == arg) { - rt.d_children[r].d_children[q].d_req_type = - TypeNode::fromType(dt[c].getArgType(r)); + rt.d_children[r].d_children[q].d_req_type = dt[c].getArgType(r); } else { - rt.d_children[r].d_children[q].d_req_type = - TypeNode::fromType(pdt[pc].getArgType(q)); + rt.d_children[r].d_children[q].d_req_type = pdt[pc].getArgType(q); } } } @@ -391,9 +382,9 @@ bool SygusSimpleSymBreak::considerArgKind( { // (ite (not y) z w) -----> (ite y w z) rt.d_req_kind = ITE; - rt.d_children[0].d_req_type = TypeNode::fromType(dt[c].getArgType(0)); - rt.d_children[1].d_req_type = TypeNode::fromType(pdt[pc].getArgType(2)); - rt.d_children[2].d_req_type = TypeNode::fromType(pdt[pc].getArgType(1)); + rt.d_children[0].d_req_type = dt[c].getArgType(0); + rt.d_children[1].d_req_type = pdt[pc].getArgType(2); + rt.d_children[2].d_req_type = pdt[pc].getArgType(1); } } Trace("sygus-sb-debug") << "Consider sygus arg kind " << k << ", pk = " << pk @@ -425,7 +416,7 @@ bool SygusSimpleSymBreak::considerArgKind( bool SygusSimpleSymBreak::considerConst( TypeNode tn, TypeNode tnp, Node c, Kind pk, int arg) { - const Datatype& pdt = static_cast(tnp.toType()).getDatatype(); + const DType& pdt = tnp.getDType(); // child grammar-independent if (!considerConst(pdt, tnp, c, pk, arg)) { @@ -477,7 +468,7 @@ bool SygusSimpleSymBreak::considerConst( } bool SygusSimpleSymBreak::considerConst( - const Datatype& pdt, TypeNode tnp, Node c, Kind pk, int arg) + const DType& pdt, TypeNode tnp, Node c, Kind pk, int arg) { quantifiers::SygusTypeInfo& pti = d_tds->getTypeInfo(tnp); Assert(pti.hasKind(pk)); @@ -490,7 +481,7 @@ bool SygusSimpleSymBreak::considerConst( if (pdt[pc].getNumArgs() == 2) { int oarg = arg == 0 ? 1 : 0; - TypeNode otn = TypeNode::fromType(pdt[pc].getArgType(oarg)); + TypeNode otn = pdt[pc].getArgType(oarg); if (otn == tnp) { Trace("sygus-sb-simple") @@ -547,8 +538,8 @@ bool SygusSimpleSymBreak::considerConst( if (c == one_c && arg == 2) { rt.d_req_kind = STRING_CHARAT; - rt.d_children[0].d_req_type = TypeNode::fromType(pdt[pc].getArgType(0)); - rt.d_children[1].d_req_type = TypeNode::fromType(pdt[pc].getArgType(1)); + rt.d_children[0].d_req_type = pdt[pc].getArgType(0); + rt.d_children[1].d_req_type = pdt[pc].getArgType(1); } } if (!rt.empty()) @@ -576,13 +567,12 @@ int SygusSimpleSymBreak::solveForArgument(TypeNode tn, return -1; } -int SygusSimpleSymBreak::getFirstArgOccurrence(const DatatypeConstructor& c, +int SygusSimpleSymBreak::getFirstArgOccurrence(const DTypeConstructor& c, TypeNode tn) { for (unsigned i = 0, nargs = c.getNumArgs(); i < nargs; i++) { - TypeNode tni = TypeNode::fromType(c.getArgType(i)); - if (tni == tn) + if (c.getArgType(i) == tn) { return i; } diff --git a/src/theory/datatypes/sygus_simple_sym.h b/src/theory/datatypes/sygus_simple_sym.h index 815466d00..59eadce93 100644 --- a/src/theory/datatypes/sygus_simple_sym.h +++ b/src/theory/datatypes/sygus_simple_sym.h @@ -18,6 +18,7 @@ #define CVC4__THEORY__DATATYPES__SIMPLE_SYM_BREAK_H #include +#include "expr/dtype.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_util.h" @@ -91,13 +92,12 @@ class SygusSimpleSymBreak /** Pointer to the quantifiers term utility */ quantifiers::TermUtil* d_tutil; /** return the index of the first argument position of c that has type tn */ - int getFirstArgOccurrence(const DatatypeConstructor& c, TypeNode tn); + int getFirstArgOccurrence(const DTypeConstructor& c, TypeNode tn); /** * Helper function for consider const above, pdt is the datatype of the type * of tnp. */ - bool considerConst( - const Datatype& pdt, TypeNode tnp, Node c, Kind pk, int arg); + bool considerConst(const DType& pdt, TypeNode tnp, Node c, Kind pk, int arg); }; } // namespace datatypes diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 2d6aeae60..5e071c85c 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -19,6 +19,7 @@ #include "base/check.h" #include "expr/datatype.h" +#include "expr/dtype.h" #include "expr/kind.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" @@ -202,11 +203,12 @@ void TheoryDatatypes::check(Effort e) { //if there are more than 1 possible constructors for eqc if( !hasLabel( eqc, n ) ){ Trace("datatypes-debug") << "No constructor..." << std::endl; - Type tt = tn.toType(); - const Datatype& dt = ((DatatypeType)tt).getDatatype(); + TypeNode tt = tn; + const DType& dt = tt.getDType(); Trace("datatypes-debug") - << "Datatype " << dt << " is " << dt.isInterpretedFinite(tt) - << " " << dt.isRecursiveSingleton(tt) << std::endl; + << "Datatype " << dt.getName() << " is " + << dt.isInterpretedFinite(tt) << " " + << dt.isRecursiveSingleton(tt) << std::endl; bool continueProc = true; if( dt.isRecursiveSingleton( tt ) ){ Trace("datatypes-debug") << "Check recursive singleton..." << std::endl; @@ -224,7 +226,7 @@ void TheoryDatatypes::check(Effort e) { //otherwise, if the logic is quantified, under the assumption that all uninterpreted sorts have cardinality one, // infer the equality. for( unsigned i=0; i& pcons ){ TypeNode tn = n.getType(); - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); int lindex = getLabelIndex( eqc, n ); pcons.resize( dt.getNumConstructors(), lindex==-1 ); if( lindex!=-1 ){ @@ -1117,8 +1115,8 @@ void TheoryDatatypes::addTester( d_labels_tindex[n].push_back(ttindex); } n_lbl++; - - const Datatype& dt = ((DatatypeType)(t_arg.getType()).toType()).getDatatype(); + + const DType& dt = t_arg.getType().getDType(); Debug("datatypes-labels") << "Labels at " << n_lbl << " / " << dt.getNumConstructors() << std::endl; if( tpolarity ){ instantiate( eqc, n ); @@ -1315,11 +1313,11 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { use_s = s; } if( s.getKind()==kind::APPLY_SELECTOR_TOTAL ){ - Expr selectorExpr = s.getOperator().toExpr(); + Node selector = s.getOperator(); size_t constructorIndex = utils::indexOf(c.getOperator()); - const Datatype& dt = Datatype::datatypeOf(selectorExpr); - const DatatypeConstructor& dtc = dt[constructorIndex]; - int selectorIndex = dtc.getSelectorIndexInternal( selectorExpr ); + const DType& dt = utils::datatypeOf(selector); + const DTypeConstructor& dtc = dt[constructorIndex]; + int selectorIndex = dtc.getSelectorIndexInternal(selector); wrong = selectorIndex<0; //if( wrong ){ @@ -1549,8 +1547,8 @@ bool TheoryDatatypes::collectModelInfo(TheoryModel* m) Node eqc = nodes[index]; Node neqc; bool addCons = false; - Type tt = eqc.getType().toType(); - const Datatype& dt = ((DatatypeType)tt).getDatatype(); + TypeNode tt = eqc.getType(); + const DType& dt = tt.getDType(); if( !d_equalityEngine.hasTerm( eqc ) ){ Assert(false); }else{ @@ -1725,7 +1723,7 @@ void TheoryDatatypes::collectTerms( Node n ) { else if (nk == DT_HEIGHT_BOUND && n[1].getConst().isZero()) { std::vector children; - const Datatype& dt = n[0].getType().getDatatype(); + const DType& dt = n[0].getType().getDType(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { if (utils::isNullaryConstructor(dt[i])) @@ -1749,7 +1747,8 @@ void TheoryDatatypes::collectTerms( Node n ) { } } -Node TheoryDatatypes::getInstantiateCons( Node n, const Datatype& dt, int index ){ +Node TheoryDatatypes::getInstantiateCons(Node n, const DType& dt, int index) +{ std::map< int, Node >::iterator it = d_inst_map[n].find( index ); if( it!=d_inst_map[n].end() ){ return it->second; @@ -1791,7 +1790,7 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){ exp = getLabel(n); tt = exp[0]; } - const Datatype& dt = ((DatatypeType)(tt.getType()).toType()).getDatatype(); + const DType& dt = tt.getType().getDType(); // instantiate this equivalence class eqc->d_inst = true; Node tt_cons = getInstantiateCons(tt, dt, index); @@ -2087,7 +2086,7 @@ bool TheoryDatatypes::mustCommunicateFact( Node n, Node exp ){ if( !tn.isDatatype() ){ addLemma = true; }else{ - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); addLemma = dt.involvesExternalType(); } }else if( n.getKind()==LEQ || n.getKind()==OR ){ diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index ba09ce89e..a878647bc 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -354,7 +354,7 @@ private: /** collect terms */ void collectTerms( Node n ); /** get instantiate cons */ - Node getInstantiateCons( Node n, const Datatype& dt, int index ); + Node getInstantiateCons(Node n, const DType& dt, int index); /** check instantiate */ void instantiate( EqcInfo* eqc, Node n ); /** must communicate fact */ diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index 60841a5dd..97e67e7fa 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -19,6 +19,7 @@ #ifndef CVC4__THEORY__DATATYPES__THEORY_DATATYPES_TYPE_RULES_H #define CVC4__THEORY__DATATYPES__THEORY_DATATYPES_TYPE_RULES_H +#include "expr/dtype.h" #include "expr/type_matcher.h" #include "theory/datatypes/theory_datatypes_utils.h" @@ -104,22 +105,6 @@ struct DatatypeConstructorTypeRule { return false; } } - //if we support subtyping for tuples, enable this - /* - //check whether it is in normal form? - TypeNode tn = n.getType(); - if( tn.isTuple() ){ - const Datatype& dt = tn.getDatatype(); - //may be the wrong constructor, if children types are subtypes - for( unsigned i=0; i patIndices; bool patHasVariable = false; @@ -510,10 +495,10 @@ class MatchTypeRule throw TypeCheckingExceptionPrivate( n, "unexpected kind of term in pattern in match"); } - const Datatype& pdt = patType.getDatatype(); + const DType& pdt = patType.getDType(); // compare datatypes instead of the types to catch parametric case, // where the pattern has parametric type. - if (hdt != pdt) + if (hdt.getTypeNode() != pdt.getTypeNode()) { std::stringstream ss; ss << "pattern of a match case does not match the head type in match"; diff --git a/src/theory/datatypes/theory_datatypes_utils.cpp b/src/theory/datatypes/theory_datatypes_utils.cpp index d2833a852..2fe8a99fe 100644 --- a/src/theory/datatypes/theory_datatypes_utils.cpp +++ b/src/theory/datatypes/theory_datatypes_utils.cpp @@ -16,6 +16,7 @@ #include "theory/datatypes/theory_datatypes_utils.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" #include "theory/evaluator.h" @@ -28,7 +29,7 @@ namespace theory { namespace datatypes { namespace utils { -Node applySygusArgs(const Datatype& dt, +Node applySygusArgs(const DType& dt, Node op, Node n, const std::vector& args) @@ -37,7 +38,7 @@ Node applySygusArgs(const Datatype& dt, { Assert(n.hasAttribute(SygusVarNumAttribute())); int vn = n.getAttribute(SygusVarNumAttribute()); - Assert(Node::fromExpr(dt.getSygusVarList())[vn] == n); + Assert(dt.getSygusVarList()[vn] == n); return args[vn]; } // n is an application of operator op. @@ -81,7 +82,7 @@ Node applySygusArgs(const Datatype& dt, } // do the full substitution std::vector vars; - Node bvl = Node::fromExpr(dt.getSygusVarList()); + Node bvl = dt.getSygusVarList(); for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++) { vars.push_back(bvl[i]); @@ -116,7 +117,7 @@ Kind getOperatorKindForSygusBuiltin(Node op) return UNDEFINED_KIND; } -Node mkSygusTerm(const Datatype& dt, +Node mkSygusTerm(const DType& dt, unsigned i, const std::vector& children, bool doBetaReduction) @@ -126,7 +127,7 @@ Node mkSygusTerm(const Datatype& dt, Assert(i < dt.getNumConstructors()); Assert(dt.isSygus()); Assert(!dt[i].getSygusOp().isNull()); - Node op = Node::fromExpr(dt[i].getSygusOp()); + Node op = dt[i].getSygusOp(); return mkSygusTerm(op, children, doBetaReduction); } @@ -203,24 +204,22 @@ Node mkSygusTerm(Node op, } /** get instantiate cons */ -Node getInstCons(Node n, const Datatype& dt, int index) +Node getInstCons(Node n, const DType& dt, int index) { Assert(index >= 0 && index < (int)dt.getNumConstructors()); std::vector children; NodeManager* nm = NodeManager::currentNM(); - children.push_back(Node::fromExpr(dt[index].getConstructor())); - Type t = n.getType().toType(); + children.push_back(dt[index].getConstructor()); + TypeNode tn = n.getType(); for (unsigned i = 0, nargs = dt[index].getNumArgs(); i < nargs; i++) { - Node nc = nm->mkNode(APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[index].getSelectorInternal(t, i)), - n); + Node nc = nm->mkNode( + APPLY_SELECTOR_TOTAL, dt[index].getSelectorInternal(tn, i), n); children.push_back(nc); } Node n_ic = nm->mkNode(APPLY_CONSTRUCTOR, children); if (dt.isParametric()) { - TypeNode tn = TypeNode::fromType(t); // add type ascription for ambiguous constructor types if (!n_ic.getType().isComparableTo(tn)) { @@ -229,12 +228,11 @@ Node getInstCons(Node n, const Datatype& dt, int index) << n.getType() << std::endl; Debug("datatypes-parametric") << "Constructor is " << dt[index] << std::endl; - Type tspec = - dt[index].getSpecializedConstructorType(n.getType().toType()); + TypeNode tspec = dt[index].getSpecializedConstructorType(n.getType()); Debug("datatypes-parametric") << "Type specification is " << tspec << std::endl; children[0] = nm->mkNode(APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(tspec)), + nm->mkConst(AscriptionType(tspec.toType())), children[0]); n_ic = nm->mkNode(APPLY_CONSTRUCTOR, children); Assert(n_ic.getType() == tn); @@ -245,18 +243,17 @@ Node getInstCons(Node n, const Datatype& dt, int index) return n_ic; } -int isInstCons(Node t, Node n, const Datatype& dt) +int isInstCons(Node t, Node n, const DType& dt) { if (n.getKind() == APPLY_CONSTRUCTOR) { int index = indexOf(n.getOperator()); - const DatatypeConstructor& c = dt[index]; - Type nt = n.getType().toType(); + const DTypeConstructor& c = dt[index]; + TypeNode tn = n.getType(); for (unsigned i = 0, size = n.getNumChildren(); i < size; i++) { if (n[i].getKind() != APPLY_SELECTOR_TOTAL - || n[i].getOperator() != Node::fromExpr(c.getSelectorInternal(nt, i)) - || n[i][0] != t) + || n[i].getOperator() != c.getSelectorInternal(tn, i) || n[i][0] != t) { return -1; } @@ -285,31 +282,29 @@ int isTester(Node n) return -1; } -struct DtIndexAttributeId -{ -}; -typedef expr::Attribute DtIndexAttribute; +size_t indexOf(Node n) { return DType::indexOf(n); } + +size_t cindexOf(Node n) { return DType::cindexOf(n); } -unsigned indexOf(Node n) +const DType& datatypeOf(Node n) { - if (!n.hasAttribute(DtIndexAttribute())) - { - Assert(n.getType().isConstructor() || n.getType().isTester() - || n.getType().isSelector()); - unsigned index = Datatype::indexOfInternal(n.toExpr()); - n.setAttribute(DtIndexAttribute(), index); - return index; + TypeNode t = n.getType(); + switch (t.getKind()) + { + case CONSTRUCTOR_TYPE: return t[t.getNumChildren() - 1].getDType(); + case SELECTOR_TYPE: + case TESTER_TYPE: return t[0].getDType(); + default: + Unhandled() << "arg must be a datatype constructor, selector, or tester"; } - return n.getAttribute(DtIndexAttribute()); } -Node mkTester(Node n, int i, const Datatype& dt) +Node mkTester(Node n, int i, const DType& dt) { - return NodeManager::currentNM()->mkNode( - APPLY_TESTER, Node::fromExpr(dt[i].getTester()), n); + return NodeManager::currentNM()->mkNode(APPLY_TESTER, dt[i].getTester(), n); } -Node mkSplit(Node n, const Datatype& dt) +Node mkSplit(Node n, const DType& dt) { std::vector splits; for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) @@ -334,7 +329,7 @@ bool isNullaryApplyConstructor(Node n) return true; } -bool isNullaryConstructor(const DatatypeConstructor& c) +bool isNullaryConstructor(const DTypeConstructor& c) { for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { @@ -433,7 +428,7 @@ Node sygusToBuiltin(Node n) { Node ret = cur; Assert(cur.getKind() == APPLY_CONSTRUCTOR); - const Datatype& dt = cur.getType().getDatatype(); + const DType& dt = cur.getType().getDType(); // Non sygus-datatype terms are also themselves. Notice we treat the // case of non-sygus datatypes this way since it avoids computing // the type / datatype of the node in the pre-traversal above. The @@ -555,7 +550,7 @@ Node sygusToBuiltinEval(Node n, const std::vector& args) { Node ret = cur; Assert(cur.getKind() == APPLY_CONSTRUCTOR); - const Datatype& dt = cur.getType().getDatatype(); + const DType& dt = cur.getType().getDType(); // non sygus-datatype terms are also themselves if (dt.isSygus()) { diff --git a/src/theory/datatypes/theory_datatypes_utils.h b/src/theory/datatypes/theory_datatypes_utils.h index 46a6d56be..b23302276 100644 --- a/src/theory/datatypes/theory_datatypes_utils.h +++ b/src/theory/datatypes/theory_datatypes_utils.h @@ -21,6 +21,7 @@ #include +#include "expr/dtype.h" #include "expr/node.h" #include "expr/node_manager_attributes.h" @@ -77,14 +78,14 @@ namespace utils { * This returns the term C( sel^{C,1}( n ), ..., sel^{C,m}( n ) ), * where C is the index^{th} constructor of datatype dt. */ -Node getInstCons(Node n, const Datatype& dt, int index); +Node getInstCons(Node n, const DType& dt, int index); /** is instantiation cons * * If this method returns a value >=0, then that value, call it index, * is such that n = C( sel^{C,1}( t ), ..., sel^{C,m}( t ) ), * where C is the index^{th} constructor of dt. */ -int isInstCons(Node t, Node n, const Datatype& dt); +int isInstCons(Node t, Node n, const DType& dt); /** is tester * * This method returns a value >=0 if n is a tester predicate. The return @@ -100,19 +101,28 @@ int isTester(Node n); * index of a selector in its constructor. (Zero is always the * first index.) */ -unsigned indexOf(Node n); +size_t indexOf(Node n); +/** + * Get the index of constructor corresponding to selector. + * (Zero is always the first index.) + */ +size_t cindexOf(Node n); +/** + * Get the datatype of n. + */ +const DType& datatypeOf(Node n); /** make tester is-C( n ), where C is the i^{th} constructor of dt */ -Node mkTester(Node n, int i, const Datatype& dt); +Node mkTester(Node n, int i, const DType& dt); /** make tester split * * Returns the formula (OR is-C1( n ) ... is-Ck( n ) ), where C1...Ck * are the constructors of n's type (dt). */ -Node mkSplit(Node n, const Datatype& dt); +Node mkSplit(Node n, const DType& dt); /** returns true iff n is a constructor term with no datatype children */ bool isNullaryApplyConstructor(Node n); /** returns true iff c is a constructor with no datatype children */ -bool isNullaryConstructor(const DatatypeConstructor& c); +bool isNullaryConstructor(const DTypeConstructor& c); /** check clash * * This method returns true if and only if n1 and n2 have a skeleton that has @@ -143,7 +153,7 @@ Kind getOperatorKindForSygusBuiltin(Node op); * encodes. If doBetaReduction is true, then lambdas are eagerly eliminated * via beta reduction. */ -Node mkSygusTerm(const Datatype& dt, +Node mkSygusTerm(const DType& dt, unsigned i, const std::vector& children, bool doBetaReduction = true); @@ -181,7 +191,7 @@ Node mkSygusTerm(Node op, * to cache the results of whether the evaluation of this constructor needs * a substitution over the formal argument list of the function-to-synthesize. */ -Node applySygusArgs(const Datatype& dt, +Node applySygusArgs(const DType& dt, Node op, Node n, const std::vector& args); diff --git a/src/theory/datatypes/type_enumerator.cpp b/src/theory/datatypes/type_enumerator.cpp index 5de04a9c3..af686ded0 100644 --- a/src/theory/datatypes/type_enumerator.cpp +++ b/src/theory/datatypes/type_enumerator.cpp @@ -119,7 +119,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ { Debug("dt-enum-debug") << "Look at constructor " << (index - d_has_debruijn) << std::endl; - DatatypeConstructor ctor = d_datatype[index - d_has_debruijn]; + const DTypeConstructor& ctor = d_datatype[index - d_has_debruijn]; Debug("dt-enum-debug") << "Check last term..." << std::endl; // we first check if the last argument (which is forced to make sum of // iterated arguments equal to d_size_limit) is defined @@ -138,14 +138,13 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ } Debug("dt-enum-debug") << "Get constructor..." << std::endl; NodeBuilder<> b(kind::APPLY_CONSTRUCTOR); - Type typ; if (d_datatype.isParametric()) { - typ = ctor.getSpecializedConstructorType(d_type.toType()); - b << NodeManager::currentNM()->mkNode( - kind::APPLY_TYPE_ASCRIPTION, - NodeManager::currentNM()->mkConst(AscriptionType(typ)), - Node::fromExpr(ctor.getConstructor())); + NodeManager* nm = NodeManager::currentNM(); + TypeNode typ = ctor.getSpecializedConstructorType(d_type); + b << nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType(typ.toType())), + ctor.getConstructor()); } else { @@ -199,8 +198,8 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ Debug("dt-enum") << "datatype is kind " << d_type.getKind() << std::endl; Debug("dt-enum") << "datatype is " << d_type << std::endl; Debug("dt-enum") << "properties : " << d_datatype.isCodatatype() << " " - << d_datatype.isRecursiveSingleton(d_type.toType()); - Debug("dt-enum") << " " << d_datatype.isInterpretedFinite(d_type.toType()) + << d_datatype.isRecursiveSingleton(d_type); + Debug("dt-enum") << " " << d_datatype.isInterpretedFinite(d_type) << std::endl; if (d_datatype.isCodatatype() && hasCyclesDt(d_datatype)) @@ -222,7 +221,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ // this datatype. Since datatypes can not be embedded in non-datatype // types (e.g. (Array D D) cannot be a subfield type of datatype D), this // call is guaranteed to avoid infinite recursion. - d_zeroTerm = Node::fromExpr(d_datatype.mkGroundValue(d_type.toType())); + d_zeroTerm = d_datatype.mkGroundValue(d_type); d_zeroTermActive = true; Debug("dt-enum-debug") << "done : " << d_zeroTerm << std::endl; Assert(d_zeroTerm.getKind() == kind::APPLY_CONSTRUCTOR); @@ -235,22 +234,22 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ d_sel_types.push_back(std::vector()); d_sel_index.push_back(std::vector()); d_sel_sum.push_back(-1); - DatatypeConstructor ctor = d_datatype[i]; - Type typ; + const DTypeConstructor& ctor = d_datatype[i]; + TypeNode typ; if (d_datatype.isParametric()) { - typ = ctor.getSpecializedConstructorType(d_type.toType()); + typ = ctor.getSpecializedConstructorType(d_type); } for (unsigned a = 0; a < ctor.getNumArgs(); ++a) { TypeNode tn; if (d_datatype.isParametric()) { - tn = TypeNode::fromType(typ)[a]; + tn = typ[a]; } else { - tn = Node::fromExpr(ctor[a].getSelector()).getType()[1]; + tn = ctor[a].getSelector().getType()[1]; } d_sel_types.back().push_back(tn); d_sel_index.back().push_back(0); @@ -309,7 +308,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ // or other cases if (prevSize == d_size_limit || (d_size_limit == 0 && d_datatype.isCodatatype()) - || !d_datatype.isInterpretedFinite(d_type.toType())) + || !d_datatype.isInterpretedFinite(d_type)) { d_size_limit++; d_ctor = 0; diff --git a/src/theory/datatypes/type_enumerator.h b/src/theory/datatypes/type_enumerator.h index 6f7fc4286..ece332bc9 100644 --- a/src/theory/datatypes/type_enumerator.h +++ b/src/theory/datatypes/type_enumerator.h @@ -19,11 +19,12 @@ #ifndef CVC4__THEORY__DATATYPES__TYPE_ENUMERATOR_H #define CVC4__THEORY__DATATYPES__TYPE_ENUMERATOR_H -#include "theory/type_enumerator.h" -#include "expr/type_node.h" -#include "expr/type.h" +#include "expr/dtype.h" #include "expr/kind.h" +#include "expr/type.h" +#include "expr/type_node.h" #include "options/quantifiers_options.h" +#include "theory/type_enumerator.h" namespace CVC4 { namespace theory { @@ -34,7 +35,7 @@ class DatatypesEnumerator : public TypeEnumeratorBase { /** type properties */ TypeEnumeratorProperties * d_tep; /** The datatype we're enumerating */ - const Datatype& d_datatype; + const DType& d_datatype; /** extra cons */ unsigned d_has_debruijn; /** type */ @@ -62,12 +63,13 @@ class DatatypesEnumerator : public TypeEnumeratorBase { /** child */ bool d_child_enum; - bool hasCyclesDt( const Datatype& dt ) { - return dt.isRecursiveSingleton( d_type.toType() ) || !dt.isFinite( d_type.toType() ); + bool hasCyclesDt(const DType& dt) + { + return dt.isRecursiveSingleton(d_type) || !dt.isFinite(d_type); } bool hasCycles( TypeNode tn ){ if( tn.isDatatype() ){ - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); return hasCyclesDt( dt ); }else{ return false; @@ -86,7 +88,7 @@ class DatatypesEnumerator : public TypeEnumeratorBase { DatatypesEnumerator(TypeNode type, TypeEnumeratorProperties* tep = nullptr) : TypeEnumeratorBase(type), d_tep(tep), - d_datatype(DatatypeType(type.toType()).getDatatype()), + d_datatype(type.getDType()), d_type(type), d_ctor(0), d_zeroTermActive(false) @@ -99,7 +101,7 @@ class DatatypesEnumerator : public TypeEnumeratorBase { TypeEnumeratorProperties* tep = nullptr) : TypeEnumeratorBase(type), d_tep(tep), - d_datatype(DatatypeType(type.toType()).getDatatype()), + d_datatype(type.getDType()), d_type(type), d_ctor(0), d_zeroTermActive(false) diff --git a/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp index 6c4f2c620..9fd682aaf 100644 --- a/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_dt_instantiator.cpp @@ -14,7 +14,9 @@ #include "theory/quantifiers/cegqi/ceg_dt_instantiator.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" +#include "theory/datatypes/theory_datatypes_utils.h" using namespace std; using namespace CVC4::kind; @@ -57,16 +59,14 @@ bool DtInstantiator::processEqualTerms(CegInstantiator* ci, << "...try based on constructor term " << n << std::endl; std::vector children; children.push_back(n.getOperator()); - const Datatype& dt = - static_cast(d_type.toType()).getDatatype(); - unsigned cindex = Datatype::indexOf(n.getOperator().toExpr()); + const DType& dt = d_type.getDType(); + unsigned cindex = datatypes::utils::indexOf(n.getOperator()); // now must solve for selectors applied to pv for (unsigned j = 0, nargs = dt[cindex].getNumArgs(); j < nargs; j++) { - Node c = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(d_type.toType(), j)), - pv); + Node c = nm->mkNode(APPLY_SELECTOR_TOTAL, + dt[cindex].getSelectorInternal(d_type, j), + pv); ci->pushStackVariable(c); children.push_back(c); } @@ -146,15 +146,13 @@ Node DtInstantiator::solve_dt(Node v, Node a, Node b, Node sa, Node sb) else { NodeManager* nm = NodeManager::currentNM(); - unsigned cindex = Datatype::indexOf(a.getOperator().toExpr()); + unsigned cindex = DType::indexOf(a.getOperator().toExpr()); TypeNode tn = a.getType(); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); for (unsigned i = 0, nchild = a.getNumChildren(); i < nchild; i++) { Node nn = nm->mkNode( - APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(tn.toType(), i)), - sb); + APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(tn, i), sb); Node s = solve_dt(v, a[i], Node::null(), sa[i], nn); if (!s.isNull()) { diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp index bed382e28..1d4a23af1 100644 --- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp @@ -341,13 +341,12 @@ CegHandledStatus CegInstantiator::isCbqiSort( // we initialize to handled, we remain handled as long as all subfields // of this datatype are not unhandled. ret = CEG_HANDLED; - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode crange = TypeNode::fromType( - static_cast(dt[i][j].getType()).getRangeType()); + TypeNode crange = dt[i].getArgType(j); CegHandledStatus cret = isCbqiSort(crange, visited, qe); if (cret == CEG_UNHANDLED) { @@ -520,15 +519,12 @@ void CegInstantiator::registerTheoryIds(TypeNode tn, registerTheoryId(tid); if (tn.isDatatype()) { - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); for (unsigned i = 0; i < dt.getNumConstructors(); i++) { for (unsigned j = 0; j < dt[i].getNumArgs(); j++) { - registerTheoryIds( - TypeNode::fromType( - ((SelectorType)dt[i][j].getType()).getRangeType()), - visited); + registerTheoryIds(dt[i].getArgType(j), visited); } } } diff --git a/src/theory/quantifiers/ematching/candidate_generator.cpp b/src/theory/quantifiers/ematching/candidate_generator.cpp index 3a075ec8a..8e09ef6a2 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.cpp +++ b/src/theory/quantifiers/ematching/candidate_generator.cpp @@ -13,6 +13,7 @@ **/ #include "theory/quantifiers/ematching/candidate_generator.h" +#include "expr/dtype.h" #include "options/quantifiers_options.h" #include "theory/quantifiers/inst_match.h" #include "theory/quantifiers/instantiate.h" @@ -208,7 +209,7 @@ CandidateGeneratorConsExpand::CandidateGeneratorConsExpand( : CandidateGeneratorQE(qe, mpat) { Assert(mpat.getKind() == APPLY_CONSTRUCTOR); - d_mpat_type = static_cast(mpat.getType().toType()); + d_mpat_type = mpat.getType(); } void CandidateGeneratorConsExpand::reset(Node eqc) @@ -222,7 +223,7 @@ void CandidateGeneratorConsExpand::reset(Node eqc) { d_eqc = eqc; d_mode = cand_term_ident; - Assert(d_eqc.getType().toType() == d_mpat_type); + Assert(d_eqc.getType() == d_mpat_type); } } @@ -237,15 +238,13 @@ Node CandidateGeneratorConsExpand::getNextCandidate() // expand it NodeManager* nm = NodeManager::currentNM(); std::vector children; - const Datatype& dt = d_mpat_type.getDatatype(); + const DType& dt = d_mpat_type.getDType(); Assert(dt.getNumConstructors() == 1); children.push_back(d_op); for (unsigned i = 0, nargs = dt[0].getNumArgs(); i < nargs; i++) { - Node sel = - nm->mkNode(APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[0].getSelectorInternal(d_mpat_type, i)), - curr); + Node sel = nm->mkNode( + APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(d_mpat_type, i), curr); children.push_back(sel); } return nm->mkNode(APPLY_CONSTRUCTOR, children); diff --git a/src/theory/quantifiers/ematching/candidate_generator.h b/src/theory/quantifiers/ematching/candidate_generator.h index 8cff12477..51c5ffa0b 100644 --- a/src/theory/quantifiers/ematching/candidate_generator.h +++ b/src/theory/quantifiers/ematching/candidate_generator.h @@ -203,7 +203,7 @@ class CandidateGeneratorConsExpand : public CandidateGeneratorQE protected: /** the (datatype) type of the input match pattern */ - DatatypeType d_mpat_type; + TypeNode d_mpat_type; /** we don't care about the operator of n */ bool isLegalOpCandidate(Node n) override; }; diff --git a/src/theory/quantifiers/ematching/inst_match_generator.cpp b/src/theory/quantifiers/ematching/inst_match_generator.cpp index e639dc446..6fdd6d67a 100644 --- a/src/theory/quantifiers/ematching/inst_match_generator.cpp +++ b/src/theory/quantifiers/ematching/inst_match_generator.cpp @@ -17,6 +17,7 @@ #include "expr/datatype.h" #include "options/datatypes_options.h" #include "options/quantifiers_options.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/ematching/candidate_generator.h" #include "theory/quantifiers/ematching/trigger.h" #include "theory/quantifiers/instantiate.h" @@ -203,7 +204,7 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector< { // 1-constructors have a trivial way of generating candidates in a // given equivalence class - const Datatype& dt = d_match_pattern.getType().getDatatype(); + const DType& dt = d_match_pattern.getType().getDType(); if (dt.getNumConstructors() == 1) { d_cg = new inst::CandidateGeneratorConsExpand(qe, d_match_pattern); @@ -226,11 +227,11 @@ void InstMatchGenerator::initialize( Node q, QuantifiersEngine* qe, std::vector< } }else if( d_match_pattern.getKind()==INST_CONSTANT ){ if( d_pattern.getKind()==APPLY_SELECTOR_TOTAL ){ - Expr selectorExpr = qe->getTermDatabase()->getMatchOperator( d_pattern ).toExpr(); - size_t selectorIndex = Datatype::cindexOf(selectorExpr); - const Datatype& dt = Datatype::datatypeOf(selectorExpr); - const DatatypeConstructor& c = dt[selectorIndex]; - Node cOp = Node::fromExpr(c.getConstructor()); + Node selectorExpr = qe->getTermDatabase()->getMatchOperator(d_pattern); + size_t selectorIndex = datatypes::utils::cindexOf(selectorExpr); + const DType& dt = datatypes::utils::datatypeOf(selectorExpr); + const DTypeConstructor& c = dt[selectorIndex]; + Node cOp = c.getConstructor(); Trace("inst-match-gen") << "Purify dt trigger " << d_pattern << ", will match terms of op " << cOp << std::endl; d_cg = new inst::CandidateGeneratorQE( qe, cOp ); }else{ diff --git a/src/theory/quantifiers/fmf/bounded_integers.cpp b/src/theory/quantifiers/fmf/bounded_integers.cpp index c6800b092..cfff64f15 100644 --- a/src/theory/quantifiers/fmf/bounded_integers.cpp +++ b/src/theory/quantifiers/fmf/bounded_integers.cpp @@ -19,6 +19,7 @@ #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/first_order_model.h" #include "theory/quantifiers/fmf/model_engine.h" #include "theory/quantifiers/term_enumeration.h" @@ -737,14 +738,17 @@ Node BoundedIntegers::matchBoundVar( Node v, Node t, Node e ){ return Node::null(); } } - const Datatype& dt = Datatype::datatypeOf( t.getOperator().toExpr() ); - unsigned index = Datatype::indexOf( t.getOperator().toExpr() ); + NodeManager* nm = NodeManager::currentNM(); + const DType& dt = datatypes::utils::datatypeOf(t.getOperator()); + unsigned index = datatypes::utils::indexOf(t.getOperator()); for( unsigned i=0; imkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[index].getSelectorInternal( e.getType().toType(), i ) ), e ); + Node se = nm->mkNode(APPLY_SELECTOR_TOTAL, + dt[index].getSelectorInternal(e.getType(), i), + e); u = matchBoundVar( v, t[i], se ); } if( !u.isNull() ){ diff --git a/src/theory/quantifiers/quant_split.cpp b/src/theory/quantifiers/quant_split.cpp index e425cd345..32bd2b0e8 100644 --- a/src/theory/quantifiers/quant_split.cpp +++ b/src/theory/quantifiers/quant_split.cpp @@ -47,16 +47,19 @@ void QuantDSplit::checkOwnership(Node q) for( unsigned i=0; iisFiniteBound( q, q[0][i] ) ){ - if (dt.isInterpretedFinite(tn.toType())) + if (dt.isInterpretedFinite(tn)) { // split if goes from being unhandled -> handled by finite // instantiation. An example is datatypes with uninterpreted sort @@ -144,20 +147,20 @@ void QuantDSplit::check(Theory::Effort e, QEffort quant_e) TypeNode tn = svar.getType(); Assert(tn.isDatatype()); std::vector cons; - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++) { std::vector vars; for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++) { - TypeNode tns = TypeNode::fromType(dt[j][k].getRangeType()); + TypeNode tns = dt[j][k].getRangeType(); Node v = nm->mkBoundVar(tns); vars.push_back(v); } std::vector bvs_cmb; bvs_cmb.insert(bvs_cmb.end(), bvs.begin(), bvs.end()); bvs_cmb.insert(bvs_cmb.end(), vars.begin(), vars.end()); - vars.insert(vars.begin(), Node::fromExpr(dt[j].getConstructor())); + vars.insert(vars.begin(), dt[j].getConstructor()); Node c = nm->mkNode(kind::APPLY_CONSTRUCTOR, vars); TNode ct = c; Node body = q[1].substitute(svar, ct); diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index 0039ec845..8d65523e1 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -14,9 +14,11 @@ #include "theory/quantifiers/quantifiers_rewriter.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/bv_inverter.h" #include "theory/quantifiers/ematching/trigger.h" #include "theory/quantifiers/quantifiers_attributes.h" @@ -308,8 +310,8 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node Trace("quantifiers-rewrite-ite-debug") << "...condition already set " << itp->second << std::endl; computeDtTesterIteSplit( n[ itp->second==n[0] ? 1 : 2 ], pcons, ncons, conj ); }else{ - Expr testerExpr = n[0].getOperator().toExpr(); - int index = Datatype::indexOf( testerExpr ); + Node tester = n[0].getOperator(); + int index = datatypes::utils::indexOf(tester); std::map< int, Node >::iterator itn = ncons[x].find( index ); if( itn!=ncons[x].end() ){ Trace("quantifiers-rewrite-ite-debug") << "...condition negated " << itn->second << std::endl; @@ -328,6 +330,7 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node } } }else{ + NodeManager* nm = NodeManager::currentNM(); Trace("quantifiers-rewrite-ite-debug") << "Return value : " << n << std::endl; std::vector< Node > children; children.push_back( n ); @@ -343,7 +346,7 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node //only if we haven't settled on a positive tester if( std::find( vars.begin(), vars.end(), x )==vars.end() ){ //check if we have exhausted all options but one - const Datatype& dt = DatatypeType(x.getType().toType()).getDatatype(); + const DType& dt = x.getType().getDType(); std::vector< Node > nchildren; int pos_cons = -1; for( int i=0; i<(int)dt.getNumConstructors(); i++ ){ @@ -355,9 +358,8 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node } } if( pos_cons>=0 ){ - const DatatypeConstructor& c = dt[pos_cons]; - Expr tester = c.getTester(); - children.push_back( NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr( tester ), x ).negate() ); + Node tester = dt[pos_cons].getTester(); + children.push_back(nm->mkNode(APPLY_TESTER, tester, x).negate()); }else{ children.insert( children.end(), nchildren.begin(), nchildren.end() ); } @@ -454,20 +456,21 @@ void setEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::v } if( addEntailedCond( n, pol, currCond, new_cond, conflict ) ){ if( n.getKind()==APPLY_TESTER ){ - const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr()); - unsigned index = Datatype::indexOf(n.getOperator().toExpr()); + NodeManager* nm = NodeManager::currentNM(); + const DType& dt = datatypes::utils::datatypeOf(n.getOperator()); + unsigned index = datatypes::utils::indexOf(n.getOperator()); Assert(dt.getNumConstructors() > 1); if( pol ){ for( unsigned i=0; imkNode( APPLY_TESTER, Node::fromExpr( dt[i].getTester() ), n[0] ); + Node t = nm->mkNode(APPLY_TESTER, dt[i].getTester(), n[0]); addEntailedCond( t, false, currCond, new_cond, conflict ); } } }else{ if( dt.getNumConstructors()==2 ){ int oindex = 1-index; - Node t = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[oindex].getTester() ), n[0] ); + Node t = nm->mkNode(APPLY_TESTER, dt[oindex].getTester(), n[0]); addEntailedCond( t, true, currCond, new_cond, conflict ); } } @@ -1011,16 +1014,16 @@ bool QuantifiersRewriter::getVarElimLit(Node lit, if (ita != args.end()) { vars.push_back(lit[0]); - Expr testerExpr = lit.getOperator().toExpr(); - int index = Datatype::indexOf(testerExpr); - const Datatype& dt = Datatype::datatypeOf(testerExpr); - const DatatypeConstructor& c = dt[index]; + Node tester = lit.getOperator(); + int index = datatypes::utils::indexOf(tester); + const DType& dt = datatypes::utils::datatypeOf(tester); + const DTypeConstructor& c = dt[index]; std::vector newChildren; - newChildren.push_back(Node::fromExpr(c.getConstructor())); + newChildren.push_back(c.getConstructor()); std::vector newVars; for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { - TypeNode tn = TypeNode::fromType(c[j].getRangeType()); + TypeNode tn = c[j].getRangeType(); Node v = nm->mkBoundVar(tn); newChildren.push_back(v); newVars.push_back(v); @@ -1081,8 +1084,8 @@ bool QuantifiersRewriter::getVarElimLit(Node lit, { Trace("var-elim-dt") << "Expand datatype variable based on : " << lit << std::endl; - Expr testerExpr = lit.getOperator().toExpr(); - unsigned index = Datatype::indexOf(testerExpr); + Node tester = lit.getOperator(); + unsigned index = datatypes::utils::indexOf(tester); Node s = datatypeExpand(index, lit[0], args); if (!s.isNull()) { @@ -1179,16 +1182,15 @@ Node QuantifiersRewriter::datatypeExpand(unsigned index, { return Node::null(); } - const Datatype& dt = - static_cast(v.getType().toType()).getDatatype(); + const DType& dt = v.getType().getDType(); Assert(index < dt.getNumConstructors()); - const DatatypeConstructor& c = dt[index]; + const DTypeConstructor& c = dt[index]; std::vector newChildren; - newChildren.push_back(Node::fromExpr(c.getConstructor())); + newChildren.push_back(c.getConstructor()); std::vector newVars; for (unsigned j = 0, nargs = c.getNumArgs(); j < nargs; j++) { - TypeNode tn = TypeNode::fromType(c.getArgType(j)); + TypeNode tn = c.getArgType(j); Node vn = NodeManager::currentNM()->mkBoundVar(tn); newChildren.push_back(vn); newVars.push_back(vn); diff --git a/src/theory/quantifiers/skolemize.cpp b/src/theory/quantifiers/skolemize.cpp index 1d2b869c4..a83303454 100644 --- a/src/theory/quantifiers/skolemize.cpp +++ b/src/theory/quantifiers/skolemize.cpp @@ -73,8 +73,8 @@ Node Skolemize::getSkolemConstant(Node q, unsigned i) return Node::null(); } -void Skolemize::getSelfSel(const Datatype& dt, - const DatatypeConstructor& dc, +void Skolemize::getSelfSel(const DType& dt, + const DTypeConstructor& dc, Node n, TypeNode ntn, std::vector& selfSel) @@ -82,14 +82,14 @@ void Skolemize::getSelfSel(const Datatype& dt, TypeNode tspec; if (dt.isParametric()) { - tspec = TypeNode::fromType( - dc.getSpecializedConstructorType(n.getType().toType())); + tspec = dc.getSpecializedConstructorType(n.getType()); Trace("sk-ind-debug") << "Specialized constructor type : " << tspec << std::endl; Assert(tspec.getNumChildren() == dc.getNumArgs()); } Trace("sk-ind-debug") << "Check self sel " << dc.getName() << " " << dt.getName() << std::endl; + NodeManager* nm = NodeManager::currentNM(); for (unsigned j = 0; j < dc.getNumArgs(); j++) { std::vector ssc; @@ -104,32 +104,17 @@ void Skolemize::getSelfSel(const Datatype& dt, } else { - TypeNode tn = TypeNode::fromType(dc[j].getRangeType()); + TypeNode tn = dc[j].getRangeType(); Trace("sk-ind-debug") << "Compare " << tn << " " << ntn << std::endl; if (tn == ntn) { ssc.push_back(n); } } - /* TODO: more than weak structural induction - else if( tn.isDatatype() && std::find( visited.begin(), visited.end(), tn - )==visited.end() ){ - visited.push_back( tn ); - const Datatype& dt = - ((DatatypeType)(subs[0].getType()).toType()).getDatatype(); - std::vector< Node > disj; - for( unsigned i=0; imkNode( - APPLY_SELECTOR_TOTAL, - dc.getSelectorInternal(n.getType().toType(), j), - n); + Node ss = nm->mkNode( + APPLY_SELECTOR_TOTAL, dc.getSelectorInternal(n.getType(), j), n); if (std::find(selfSel.begin(), selfSel.end(), ss) == selfSel.end()) { selfSel.push_back(ss); @@ -146,6 +131,7 @@ Node Skolemize::mkSkolemizedBody(Node f, Node& sub, std::vector& sub_vars) { + NodeManager* nm = NodeManager::currentNM(); Assert(sk.empty() || sk.size() == f[0].getNumChildren()); // calculate the variables and substitution std::vector ind_vars; @@ -220,17 +206,14 @@ Node Skolemize::mkSkolemizedBody(Node f, // the following constructs ~( R( x, k ) => ~P( x ) ) if (options::dtStcInduction() && tn.isDatatype()) { - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); std::vector disj; for (unsigned i = 0; i < dt.getNumConstructors(); i++) { std::vector selfSel; getSelfSel(dt, dt[i], k, tn, selfSel); std::vector conj; - conj.push_back( - NodeManager::currentNM() - ->mkNode(APPLY_TESTER, Node::fromExpr(dt[i].getTester()), k) - .negate()); + conj.push_back(nm->mkNode(APPLY_TESTER, dt[i].getTester(), k).negate()); for (unsigned j = 0; j < selfSel.size(); j++) { conj.push_back(ret.substitute(ind_vars[0], selfSel[j]).negate()); @@ -346,7 +329,7 @@ bool Skolemize::isInductionTerm(Node n) TypeNode tn = n.getType(); if (options::dtStcInduction() && tn.isDatatype()) { - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); return !dt.isCodatatype(); } if (options::intWfInduction() && n.getType().isInteger()) diff --git a/src/theory/quantifiers/skolemize.h b/src/theory/quantifiers/skolemize.h index f07bbdfd3..86af1ee1b 100644 --- a/src/theory/quantifiers/skolemize.h +++ b/src/theory/quantifiers/skolemize.h @@ -123,8 +123,8 @@ class Skolemize * applied to term n, whose return type in ntn, and stores * them in the vector selfSel. */ - static void getSelfSel(const Datatype& dt, - const DatatypeConstructor& dc, + static void getSelfSel(const DType& dt, + const DTypeConstructor& dc, Node n, TypeNode ntn, std::vector& selfSel); diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp index 61d891f75..69e0ef70a 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv.cpp @@ -467,8 +467,8 @@ Node CegSingleInv::getSolution(unsigned sol_index, bool rconsSygus) { Assert(d_sol != NULL); - const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); - Node varList = Node::fromExpr( dt.getSygusVarList() ); + const DType& dt = stn.getDType(); + Node varList = dt.getSygusVarList(); Node prog = d_quant[0][sol_index]; std::vector< Node > vars; Node s; @@ -478,8 +478,7 @@ Node CegSingleInv::getSolution(unsigned sol_index, || d_inst.empty()) { Trace("csi-sol") << "Get solution for (unconstrained) " << prog << std::endl; - s = d_qe->getTermEnumeration()->getEnumerateTerm( - TypeNode::fromType(dt.getSygusType()), 0); + s = d_qe->getTermEnumeration()->getEnumerateTerm(dt.getSygusType(), 0); } else { @@ -548,7 +547,7 @@ Node CegSingleInv::reconstructToSyntax(Node s, bool rconsSygus) { d_solution = s; - const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); + const DType& dt = stn.getDType(); //reconstruct the solution into sygus if necessary reconstructed = 0; @@ -621,7 +620,7 @@ Node CegSingleInv::reconstructToSyntax(Node s, } //make into lambda if( !dt.getSygusVarList().isNull() ){ - Node varList = Node::fromExpr( dt.getSygusVarList() ); + Node varList = dt.getSygusVarList(); return NodeManager::currentNM()->mkNode( LAMBDA, varList, sol ); }else{ return sol; diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp b/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp index 811210628..113da2acb 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/ce_guided_single_inv_sol.h" #include "expr/datatype.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "options/quantifiers_options.h" #include "theory/arith/arith_msum.h" @@ -27,7 +28,6 @@ #include "theory/quantifiers/term_enumeration.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" -#include "theory/theory_engine.h" using namespace CVC4::kind; using namespace std; @@ -132,7 +132,7 @@ Node CegSingleInvSol::reconstructSolution(Node sol, { TypeNode tn = it->first; Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Trace("csi-rcons") << "Terms to reconstruct of type " << dt.getName() << " : " << std::endl; for (std::map::iterator it2 = it->second.begin(); @@ -225,7 +225,7 @@ int CegSingleInvSol::collectReconstructNodes(Node t, TypeNode stn, int& status) d_rcons_to_status[stn][t] = -1; TypeNode tn = t.getType(); Assert(stn.isDatatype()); - const Datatype& dt = stn.getDatatype(); + const DType& dt = stn.getDType(); TermDbSygus* tds = d_qe->getTermDatabaseSygus(); SygusTypeInfo& sti = tds->getTypeInfo(stn); Assert(dt.isSygus()); @@ -240,7 +240,8 @@ int CegSingleInvSol::collectReconstructNodes(Node t, TypeNode stn, int& status) carg = sti.getOpConsNum(min_t); if( carg!=-1 ){ Trace("csi-rcons-debug") << " Type has operator." << std::endl; - d_reconstruct[id] = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); + d_reconstruct[id] = NodeManager::currentNM()->mkNode( + APPLY_CONSTRUCTOR, dt[carg].getConstructor()); status = 0; }else{ //check if kind is in syntax sort @@ -266,7 +267,7 @@ int CegSingleInvSol::collectReconstructNodes(Node t, TypeNode stn, int& status) if( tchildren.size()==dt[karg].getNumArgs() ){ Trace("csi-rcons-debug") << "Type for " << id << " has kind " << min_t.getKind() << ", recurse." << std::endl; status = 0; - Node cons = Node::fromExpr( dt[karg].getConstructor() ); + Node cons = dt[karg].getConstructor(); if( !collectReconstructNodes( id, tchildren, dt[karg], d_reconstruct_op[id][cons], status ) ){ Trace("csi-rcons-debug") << "...failure for " << id << " " << dt[karg].getName() << std::endl; d_reconstruct_op[id].erase( cons ); @@ -295,10 +296,10 @@ int CegSingleInvSol::collectReconstructNodes(Node t, TypeNode stn, int& status) //try to directly reconstruct from single argument std::vector< Node > tchildren; tchildren.push_back( min_t ); - TypeNode stnc = TypeNode::fromType( ((SelectorType)dt[ii][0].getType()).getRangeType() ); + TypeNode stnc = dt[ii][0].getRangeType(); Trace("csi-rcons-debug") << "...try identity function " << dt[ii].getSygusOp() << ", child type is " << stnc << std::endl; status = 0; - Node cons = Node::fromExpr( dt[ii].getConstructor() ); + Node cons = dt[ii].getConstructor(); if( !collectReconstructNodes( id, tchildren, dt[ii], d_reconstruct_op[id][cons], status ) ){ d_reconstruct_op[id].erase( cons ); status = 1; @@ -320,7 +321,7 @@ int CegSingleInvSol::collectReconstructNodes(Node t, TypeNode stn, int& status) { success = true; status = 0; - Node cons = Node::fromExpr( dt[index_found].getConstructor() ); + Node cons = dt[index_found].getConstructor(); Trace("csi-rcons-debug") << "Try alternative for " << id << ", matching " << dt[index_found].getName() << " with children : " << std::endl; for( unsigned i=0; i& ts, - const DatatypeConstructor& dtc, + const DTypeConstructor& dtc, std::vector& ids, int& status) { @@ -510,7 +511,7 @@ int CegSingleInvSol::allocate(Node n, TypeNode stn) if( it==d_rcons_to_id[stn].end() ){ int ret = d_id_count; if( Trace.isOn("csi-rcons-debug") ){ - const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); + const DType& dt = stn.getDType(); Trace("csi-rcons-debug") << "id " << ret << " : " << n << " " << dt.getName() << std::endl; } d_id_node[d_id_count] = n; @@ -696,7 +697,7 @@ Node CegSingleInvSol::builtinToSygusConst(Node c, TypeNode tn, int rcons_depth) d_builtin_const_to_sygus[tn][c] = c; return c; } - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl; if (!dt.isSygus()) @@ -716,8 +717,7 @@ Node CegSingleInvSol::builtinToSygusConst(Node c, TypeNode tn, int rcons_depth) int carg = ti.getOpConsNum(c); if (carg != -1) { - sc = nm->mkNode(APPLY_CONSTRUCTOR, - Node::fromExpr(dt[carg].getConstructor())); + sc = nm->mkNode(APPLY_CONSTRUCTOR, dt[carg].getConstructor()); } else { @@ -734,8 +734,7 @@ Node CegSingleInvSol::builtinToSygusConst(Node c, TypeNode tn, int rcons_depth) Node n = builtinToSygusConst(c, tnc, rcons_depth); if (!n.isNull()) { - sc = nm->mkNode( - APPLY_CONSTRUCTOR, Node::fromExpr(dt[ii].getConstructor()), n); + sc = nm->mkNode(APPLY_CONSTRUCTOR, dt[ii].getConstructor(), n); break; } } @@ -744,16 +743,14 @@ Node CegSingleInvSol::builtinToSygusConst(Node c, TypeNode tn, int rcons_depth) if (rcons_depth < 1000) { // accelerated, recursive reconstruction of constants - Kind pk = getPlusKind(TypeNode::fromType(dt.getSygusType())); + Kind pk = getPlusKind(dt.getSygusType()); if (pk != UNDEFINED_KIND) { int arg = ti.getKindConsNum(pk); if (arg != -1) { - Kind ck = - getComparisonKind(TypeNode::fromType(dt.getSygusType())); - Kind pkm = - getPlusKind(TypeNode::fromType(dt.getSygusType()), true); + Kind ck = getComparisonKind(dt.getSygusType()); + Kind pkm = getPlusKind(dt.getSygusType(), true); // get types Assert(dt[arg].getNumArgs() == 2); TypeNode tn1 = tds->getArgType(dt[arg], 0); @@ -780,7 +777,7 @@ Node CegSingleInvSol::builtinToSygusConst(Node c, TypeNode tn, int rcons_depth) Node sc1 = builtinToSygusConst(c1, tn1, rcons_depth); Assert(!sc1.isNull()); sc = nm->mkNode(APPLY_CONSTRUCTOR, - Node::fromExpr(dt[arg].getConstructor()), + dt[arg].getConstructor(), sc1, sc2); break; @@ -819,9 +816,9 @@ void CegSingleInvSol::registerType(TypeNode tn) TermDbSygus* tds = d_qe->getTermDatabaseSygus(); // ensure it is registered tds->registerSygusType(tn); - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); - TypeNode btn = TypeNode::fromType(dt.getSygusType()); + TypeNode btn = dt.getSygusType(); // for constant reconstruction Kind ck = getComparisonKind(btn); Node z = d_qe->getTermUtil()->getTypeValue(btn, 0); @@ -829,7 +826,7 @@ void CegSingleInvSol::registerType(TypeNode tn) // iterate over constructors for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { - Node n = Node::fromExpr(dt[i].getSygusOp()); + Node n = dt[i].getSygusOp(); if (n.getKind() != kind::BUILTIN && n.isConst()) { d_const_list[tn].push_back(n); @@ -927,7 +924,7 @@ bool CegSingleInvSol::getMatch(Node t, int index_start) { Assert(st.isDatatype()); - const Datatype& dt = static_cast(st.toType()).getDatatype(); + const DType& dt = st.getDType(); Assert(dt.isSygus()); std::map > kgens; std::vector gens; @@ -975,7 +972,7 @@ bool CegSingleInvSol::getMatch(Node t, return false; } -Node CegSingleInvSol::getGenericBase(TypeNode tn, const Datatype& dt, int c) +Node CegSingleInvSol::getGenericBase(TypeNode tn, const DType& dt, int c) { std::map::iterator it = d_generic_base[tn].find(c); if (it != d_generic_base[tn].end()) diff --git a/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.h b/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.h index c319080af..ed84c81b2 100644 --- a/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.h +++ b/src/theory/quantifiers/sygus/ce_guided_single_inv_sol.h @@ -106,7 +106,11 @@ private: int allocate( Node n, TypeNode stn ); // term t with sygus type st, returns inducted templated form of t int collectReconstructNodes( Node t, TypeNode stn, int& status ); - bool collectReconstructNodes( int pid, std::vector< Node >& ts, const DatatypeConstructor& dtc, std::vector< int >& ids, int& status ); + bool collectReconstructNodes(int pid, + std::vector& ts, + const DTypeConstructor& dtc, + std::vector& ids, + int& status); bool getPathToRoot( int id ); void setReconstructed( int id, Node n ); //get equivalent terms to n with top symbol k @@ -140,7 +144,7 @@ private: * This returns the builtin term that is the analog of an application of the * c^th constructor of dt to fresh variables. */ - Node getGenericBase(TypeNode tn, const Datatype& dt, int c); + Node getGenericBase(TypeNode tn, const DType& dt, int c); /** cache for the above function */ std::map > d_generic_base; /** get match diff --git a/src/theory/quantifiers/sygus/cegis_core_connective.cpp b/src/theory/quantifiers/sygus/cegis_core_connective.cpp index 0b9dd3b48..573e11426 100644 --- a/src/theory/quantifiers/sygus/cegis_core_connective.cpp +++ b/src/theory/quantifiers/sygus/cegis_core_connective.cpp @@ -197,7 +197,7 @@ bool CegisCoreConnective::processInitialize(Node conj, // candidate has the production rule gt -> AND( gt, gt ). Similarly for // precondition and OR. Assert(gt.isDatatype()); - const Datatype& gdt = gt.getDatatype(); + const DType& gdt = gt.getDType(); SygusTypeInfo& gti = d_tds->getTypeInfo(gt); for (unsigned r = 0; r < 2; r++) { @@ -211,12 +211,12 @@ bool CegisCoreConnective::processInitialize(Node conj, Kind rk = r == 0 ? OR : AND; int i = gti.getKindConsNum(rk); if (i != -1 && gdt[i].getNumArgs() == 2 - && TypeNode::fromType(gdt[i].getArgType(0)) == gt - && TypeNode::fromType(gdt[i].getArgType(1)) == gt) + && gdt[i].getArgType(0) == gt + && gdt[i].getArgType(1) == gt) { Trace("sygus-ccore-init") << " will do " << (r == 0 ? "pre" : "post") << "condition." << std::endl; - Node cons = Node::fromExpr(gdt[i].getConstructor()); + Node cons = gdt[i].getConstructor(); c.initialize(f, cons); // Register the symmetry breaking lemma: do not do top-level solutions // with this constructor (e.g. we want to enumerate literals, not diff --git a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp index 0cdfe4307..dd9af2c43 100644 --- a/src/theory/quantifiers/sygus/enum_stream_substitution.cpp +++ b/src/theory/quantifiers/sygus/enum_stream_substitution.cpp @@ -45,7 +45,7 @@ void EnumStreamPermutation::reset(Node value) d_value = value; // get variables in value's type TypeNode tn = value.getType(); - Node var_list = Node::fromExpr(tn.getDatatype().getSygusVarList()); + Node var_list = tn.getDType().getSygusVarList(); NodeManager* nm = NodeManager::currentNM(); // get subtypes in value's type SygusTypeInfo& ti = d_tds->getTypeInfo(tn); @@ -58,10 +58,10 @@ void EnumStreamPermutation::reset(Node value) // collect constructors for variable in all subtypes for (const TypeNode& stn : sf_types) { - const Datatype& dt = stn.getDatatype(); + const DType& dt = stn.getDType(); for (unsigned i = 0, size = dt.getNumConstructors(); i < size; ++i) { - if (dt[i].getNumArgs() == 0 && Node::fromExpr(dt[i].getSygusOp()) == v) + if (dt[i].getNumArgs() == 0 && dt[i].getSygusOp() == v) { Node cons = nm->mkNode(APPLY_CONSTRUCTOR, dt[i].getConstructor()); d_var_tn_cons[v][stn] = cons; @@ -337,7 +337,7 @@ void EnumStreamSubstitution::initialize(TypeNode tn) { d_tn = tn; // get variables in value's type - Node var_list = Node::fromExpr(tn.getDatatype().getSygusVarList()); + Node var_list = tn.getDType().getSygusVarList(); // get subtypes in value's type NodeManager* nm = NodeManager::currentNM(); SygusTypeInfo& ti = d_tds->getTypeInfo(tn); @@ -349,10 +349,10 @@ void EnumStreamSubstitution::initialize(TypeNode tn) // collect constructors for variable in all subtypes for (const TypeNode& stn : sf_types) { - const Datatype& dt = stn.getDatatype(); + const DType& dt = stn.getDType(); for (unsigned i = 0, size = dt.getNumConstructors(); i < size; ++i) { - if (dt[i].getNumArgs() == 0 && Node::fromExpr(dt[i].getSygusOp()) == v) + if (dt[i].getNumArgs() == 0 && dt[i].getSygusOp() == v) { d_var_tn_cons[v][stn] = nm->mkNode(APPLY_CONSTRUCTOR, dt[i].getConstructor()); diff --git a/src/theory/quantifiers/sygus/sygus_abduct.cpp b/src/theory/quantifiers/sygus/sygus_abduct.cpp index 0396aba86..a58c5d841 100644 --- a/src/theory/quantifiers/sygus/sygus_abduct.cpp +++ b/src/theory/quantifiers/sygus/sygus_abduct.cpp @@ -16,9 +16,11 @@ #include "theory/quantifiers/sygus/sygus_abduct.h" #include "expr/datatype.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "expr/sygus_datatype.h" #include "printer/sygus_print_callback.h" +#include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/quantifiers_rewriter.h" #include "theory/quantifiers/sygus/sygus_grammar_cons.h" @@ -85,13 +87,13 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, // if provided, we will associate it with the function-to-synthesize if (!abdGType.isNull()) { - Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus()); + Assert(abdGType.isDatatype() && abdGType.getDType().isSygus()); // must convert all constructors to version with bound variables in "vars" std::vector sdts; std::set unres; Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl; - Trace("sygus-abduct-debug") << abdGType.getDatatype() << std::endl; + Trace("sygus-abduct-debug") << abdGType.getDType().getName() << std::endl; // datatype types we need to process std::vector dtToProcess; @@ -99,7 +101,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, std::map dtProcessed; dtToProcess.push_back(abdGType); std::stringstream ssutn0; - ssutn0 << abdGType.getDatatype().getName() << "_s"; + ssutn0 << abdGType.getDType().getName() << "_s"; TypeNode abdTNew = nm->mkSort(ssutn0.str(), ExprManager::SORT_FLAG_PLACEHOLDER); unres.insert(abdTNew.toType()); @@ -126,8 +128,8 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, std::vector dtNextToProcess; for (const TypeNode& curr : dtToProcess) { - Assert(curr.isDatatype() && curr.getDatatype().isSygus()); - const Datatype& dtc = curr.getDatatype(); + Assert(curr.isDatatype() && curr.getDType().isSygus()); + const DType& dtc = curr.getDType(); std::stringstream ssdtn; ssdtn << dtc.getName() << "_s"; sdts.push_back(SygusDatatype(ssdtn.str())); @@ -136,7 +138,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, << std::endl; for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++) { - Node op = Node::fromExpr(dtc[j].getSygusOp()); + Node op = dtc[j].getSygusOp(); // apply the substitution to the argument Node ops = op.substitute( syms.begin(), syms.end(), varlist.begin(), varlist.end()); @@ -145,14 +147,14 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, std::vector cargs; for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++) { - TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k)); + TypeNode argt = dtc[j].getArgType(k); std::map::iterator itdp = dtProcessed.find(argt); TypeNode argtNew; if (itdp == dtProcessed.end()) { std::stringstream ssutn; - ssutn << argt.getDatatype().getName() << "_s"; + ssutn << argt.getDType().getName() << "_s"; argtNew = nm->mkSort(ssutn.str(), ExprManager::SORT_FLAG_PLACEHOLDER); Trace("sygus-abduct-debug") @@ -196,7 +198,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, } Trace("sygus-abduct-debug") << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl; - TypeNode stn = TypeNode::fromType(dtc.getSygusType()); + TypeNode stn = dtc.getSygusType(); sdts.back().initializeDatatype( stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll()); } @@ -222,7 +224,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, Trace("sygus-abduct-debug") << "Made datatype types:" << std::endl; for (unsigned j = 0, ndts = datatypeTypes.size(); j < ndts; j++) { - const Datatype& dtj = datatypeTypes[j].getDatatype(); + const DType& dtj = TypeNode::fromType(datatypeTypes[j]).getDType(); Trace("sygus-abduct-debug") << "#" << j << ": " << dtj << std::endl; for (unsigned k = 0, ncons = dtj.getNumConstructors(); k < ncons; k++) { diff --git a/src/theory/quantifiers/sygus/sygus_enumerator.cpp b/src/theory/quantifiers/sygus/sygus_enumerator.cpp index e4c23977e..f6ec58f56 100644 --- a/src/theory/quantifiers/sygus/sygus_enumerator.cpp +++ b/src/theory/quantifiers/sygus/sygus_enumerator.cpp @@ -35,7 +35,7 @@ void SygusEnumerator::initialize(Node e) d_enum = e; d_etype = d_enum.getType(); Assert(d_etype.isDatatype()); - Assert(d_etype.getDatatype().isSygus()); + Assert(d_etype.getDType().isSygus()); d_tlEnum = getMasterEnumForType(d_etype); d_abortSize = options::sygusAbortSize(); @@ -50,7 +50,7 @@ void SygusEnumerator::initialize(Node e) TNode agt = ag; TNode truent = truen; Assert(d_tcache.find(d_etype) != d_tcache.end()); - const Datatype& dt = d_etype.getDatatype(); + const DType& dt = d_etype.getDType(); for (const Node& lem : sbl) { if (!d_tds->isSymBreakLemmaTemplate(lem)) @@ -86,7 +86,7 @@ void SygusEnumerator::initialize(Node e) { if (a == e) { - Node cons = Node::fromExpr(dt[tst].getConstructor()); + Node cons = dt[tst].getConstructor(); Trace("sygus-enum") << " ...unit exclude constructor #" << tst << ", constructor " << cons << std::endl; d_sbExcTlCons.insert(cons); @@ -168,7 +168,7 @@ void SygusEnumerator::TermCache::initialize(Node e, // not a datatype, finish return; } - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { // not a sygus datatype, finish @@ -200,7 +200,7 @@ void SygusEnumerator::TermCache::initialize(Node e, // record type information for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode tn = TypeNode::fromType(dt[i].getArgType(j)); + TypeNode tn = dt[i].getArgType(j); argTypes[i].push_back(tn); } } @@ -544,7 +544,7 @@ void SygusEnumerator::initializeTermCache(TypeNode tn) SygusEnumerator::TermEnum* SygusEnumerator::getMasterEnumForType(TypeNode tn) { - if (tn.isDatatype() && tn.getDatatype().isSygus()) + if (tn.isDatatype() && tn.getDType().isSygus()) { std::map::iterator it = d_masterEnum.find(tn); if (it != d_masterEnum.end()) @@ -627,11 +627,11 @@ Node SygusEnumerator::TermEnumMaster::getCurrent() d_currTermSet = true; // construct based on the children std::vector children; - const Datatype& dt = d_tn.getDatatype(); + const DType& dt = d_tn.getDType(); Assert(d_consNum > 0 && d_consNum <= d_ccCons.size()); // get the current constructor number unsigned cnum = d_ccCons[d_consNum - 1]; - children.push_back(Node::fromExpr(dt[cnum].getConstructor())); + children.push_back(dt[cnum].getConstructor()); // add the current of each child to children for (unsigned i = 0, nargs = dt[cnum].getNumArgs(); i < nargs; i++) { diff --git a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp index 42ddbbb7d..0cc57e0ec 100644 --- a/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp +++ b/src/theory/quantifiers/sygus/sygus_eval_unfold.cpp @@ -47,7 +47,7 @@ void SygusEvalUnfold::registerEvalTerm(Node n) TypeNode tn = n[0].getType(); // since n[0] is an evaluation head, we know tn is a sygus datatype Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); if (n[0].getKind() == APPLY_CONSTRUCTOR) { @@ -57,7 +57,7 @@ void SygusEvalUnfold::registerEvalTerm(Node n) } // register this evaluation term with its head d_evals[n[0]].push_back(n); - Node var_list = Node::fromExpr(dt.getSygusVarList()); + Node var_list = dt.getSygusVarList(); d_eval_args[n[0]].push_back(std::vector()); for (unsigned j = 1, size = n.getNumChildren(); j < size; j++) { @@ -109,7 +109,7 @@ void SygusEvalUnfold::registerModelValue(Node a, bool hasSymCons = sti.hasSubtermSymbolicCons(); // n occurs as an evaluation head, thus it has sygus datatype type Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); Trace("sygus-eval-unfold") << "SygusEvalUnfold: Register model value : " << vn << " for " << n @@ -121,7 +121,7 @@ void SygusEvalUnfold::registerModelValue(Node a, Node bTerm = d_tds->sygusToBuiltin(vn, tn); Trace("sygus-eval-unfold") << "Built-in term : " << bTerm << std::endl; std::vector vars; - Node var_list = Node::fromExpr(dt.getSygusVarList()); + Node var_list = dt.getSygusVarList(); for (const Node& v : var_list) { vars.push_back(v); @@ -240,22 +240,20 @@ Node SygusEvalUnfold::unfold(Node en, } TypeNode headType = en[0].getType(); - Type headTypeT = headType.toType(); NodeManager* nm = NodeManager::currentNM(); - const Datatype& dt = headType.getDatatype(); + const DType& dt = headType.getDType(); unsigned i = datatypes::utils::indexOf(ev.getOperator()); if (track_exp) { // explanation - Node ee = - nm->mkNode(APPLY_TESTER, Node::fromExpr(dt[i].getTester()), en[0]); + Node ee = nm->mkNode(APPLY_TESTER, dt[i].getTester(), en[0]); if (std::find(exp.begin(), exp.end(), ee) == exp.end()) { exp.push_back(ee); } } // if we are a symbolic constructor, unfolding returns the subterm itself - Node sop = Node::fromExpr(dt[i].getSygusOp()); + Node sop = dt[i].getSygusOp(); if (sop.getAttribute(SygusAnyConstAttribute())) { Trace("sygus-eval-unfold-debug") @@ -272,7 +270,7 @@ Node SygusEvalUnfold::unfold(Node en, else { Node ret = nm->mkNode( - APPLY_SELECTOR_TOTAL, dt[i].getSelectorInternal(headTypeT, 0), en[0]); + APPLY_SELECTOR_TOTAL, dt[i].getSelectorInternal(headType, 0), en[0]); Trace("sygus-eval-unfold-debug") << "...return (from constructor) " << ret << std::endl; return ret; @@ -295,7 +293,7 @@ Node SygusEvalUnfold::unfold(Node en, else { s = nm->mkNode( - APPLY_SELECTOR_TOTAL, dt[i].getSelectorInternal(headTypeT, j), en[0]); + APPLY_SELECTOR_TOTAL, dt[i].getSelectorInternal(headType, j), en[0]); } cc.push_back(s); if (track_exp) diff --git a/src/theory/quantifiers/sygus/sygus_explain.cpp b/src/theory/quantifiers/sygus/sygus_explain.cpp index cf1993efb..8fb7cf2e7 100644 --- a/src/theory/quantifiers/sygus/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus/sygus_explain.cpp @@ -14,6 +14,7 @@ #include "theory/quantifiers/sygus/sygus_explain.h" +#include "expr/dtype.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -138,7 +139,7 @@ void SygusExplain::getExplanationForEquality(Node n, return; } Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR); - const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); int i = datatypes::utils::indexOf(vn.getOperator()); Node tst = datatypes::utils::mkTester(n, i, dt); exp.push_back(tst); @@ -147,9 +148,7 @@ void SygusExplain::getExplanationForEquality(Node n, if (cexc.find(j) == cexc.end()) { Node sel = NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[i].getSelectorInternal(tn.toType(), j)), - n); + kind::APPLY_SELECTOR_TOTAL, dt[i].getSelectorInternal(tn, j), n); getExplanationForEquality(sel, vn[j], exp); } } @@ -227,7 +226,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, trb.replaceChild(i, vn[i]); } } - const Datatype& dt = ((DatatypeType)ntn.toType()).getDatatype(); + const DType& dt = ntn.getDType(); int cindex = datatypes::utils::indexOf(vn.getOperator()); Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); Node tst = datatypes::utils::mkTester(n, cindex, dt); @@ -245,9 +244,7 @@ void SygusExplain::getExplanationFor(TermRecBuild& trb, for (unsigned i = 0; i < vn.getNumChildren(); i++) { Node sel = NodeManager::currentNM()->mkNode( - kind::APPLY_SELECTOR_TOTAL, - Node::fromExpr(dt[cindex].getSelectorInternal(ntn.toType(), i)), - n); + kind::APPLY_SELECTOR_TOTAL, dt[cindex].getSelectorInternal(ntn, i), n); Node vnr_c = vnr.isNull() ? vnr : (vn[i] == vnr[i] ? Node::null() : vnr[i]); if (cexc.find(i) == cexc.end()) { diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp index b1b1ea62c..d9ce08d49 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.cpp @@ -14,6 +14,7 @@ #include "theory/quantifiers/sygus/sygus_grammar_red.h" +#include "expr/dtype.h" #include "expr/sygus_datatype.h" #include "options/quantifiers_options.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -35,14 +36,14 @@ void SygusRedundantCons::initialize(QuantifiersEngine* qe, TypeNode tn) Assert(tn.isDatatype()); TermDbSygus* tds = qe->getTermDatabaseSygus(); tds->registerSygusType(tn); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); - TypeNode btn = TypeNode::fromType(dt.getSygusType()); + TypeNode btn = dt.getSygusType(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { Trace("sygus-red") << " Is " << dt[i].getName() << " a redundant operator?" << std::endl; - Node sop = Node::fromExpr(dt[i].getSygusOp()); + Node sop = dt[i].getSygusOp(); if (sop.getAttribute(SygusAnyConstAttribute())) { // the any constant constructor is never redundant @@ -101,7 +102,7 @@ void SygusRedundantCons::initialize(QuantifiersEngine* qe, TypeNode tn) void SygusRedundantCons::getRedundant(std::vector& indices) { - const Datatype& dt = static_cast(d_type.toType()).getDatatype(); + const DType& dt = d_type.getDType(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { if (isRedundant(i)) @@ -118,7 +119,7 @@ bool SygusRedundantCons::isRedundant(unsigned i) } void SygusRedundantCons::getGenericList(TermDbSygus* tds, - const Datatype& dt, + const DType& dt, unsigned c, unsigned index, std::map& pre, diff --git a/src/theory/quantifiers/sygus/sygus_grammar_red.h b/src/theory/quantifiers/sygus/sygus_grammar_red.h index 317892723..f0743027e 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_red.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_red.h @@ -113,7 +113,7 @@ class SygusRedundantCons * to terms. */ void getGenericList(TermDbSygus* tds, - const Datatype& dt, + const DType& dt, unsigned c, unsigned index, std::map& pre, diff --git a/src/theory/quantifiers/sygus/sygus_repair_const.cpp b/src/theory/quantifiers/sygus/sygus_repair_const.cpp index 5511adb18..9ab94d1bc 100644 --- a/src/theory/quantifiers/sygus/sygus_repair_const.cpp +++ b/src/theory/quantifiers/sygus/sygus_repair_const.cpp @@ -70,7 +70,7 @@ void SygusRepairConst::registerSygusType(TypeNode tn, // "any constant" constructors return; } - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { // may have recursed to a non-sygus-datatype @@ -83,7 +83,7 @@ void SygusRepairConst::registerSygusType(TypeNode tn, } for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { - const DatatypeConstructor& dtc = dt[i]; + const DTypeConstructor& dtc = dt[i]; // recurse on all subfields for (unsigned j = 0, nargs = dtc.getNumArgs(); j < nargs; j++) { @@ -366,14 +366,14 @@ bool SygusRepairConst::isRepairable(Node n, bool useConstantsAsHoles) } TypeNode tn = n.getType(); Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { return false; } Node op = n.getOperator(); unsigned cindex = datatypes::utils::indexOf(op); - Node sygusOp = Node::fromExpr(dt[cindex].getSygusOp()); + Node sygusOp = dt[cindex].getSygusOp(); if (sygusOp.getAttribute(SygusAnyConstAttribute())) { // if it represents "any constant" then it is repairable diff --git a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp index f87b906e1..052546c0e 100644 --- a/src/theory/quantifiers/sygus/sygus_unif_strat.cpp +++ b/src/theory/quantifiers/sygus/sygus_unif_strat.cpp @@ -14,6 +14,7 @@ #include "theory/quantifiers/sygus/sygus_unif_strat.h" +#include "expr/dtype.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/sygus_eval_unfold.h" #include "theory/quantifiers/sygus/sygus_unif.h" @@ -141,8 +142,7 @@ void SygusUnifStrategy::registerStrategyPoint(Node et, if (d_einfo.find(et) == d_einfo.end()) { Trace("sygus-unif-debug") - << "...register " << et << " for " - << static_cast(tn.toType()).getDatatype().getName(); + << "...register " << et << " for " << tn.getDType().getName(); Trace("sygus-unif-debug") << ", role = " << enum_role << ", in search = " << inSearch << std::endl; d_einfo[et].initialize(enum_role); @@ -196,8 +196,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) ee = nm->mkSkolem("ee", tn); eti.d_enum[erole] = ee; Trace("sygus-unif-debug") - << "...enumerator " << ee << " for " - << static_cast(tn.toType()).getDatatype().getName() + << "...enumerator " << ee << " for " << tn.getDType().getName() << ", role = " << erole << std::endl; } else @@ -217,7 +216,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) // we know this is a sygus datatype since it is either the top-level type // in the strategy graph, or was recursed by a strategy we inferred. Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); std::map > cop_to_strat; @@ -232,8 +231,8 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) bool search_this = false; for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++) { - Node cop = Node::fromExpr(dt[j].getConstructor()); - Node op = Node::fromExpr(dt[j].getSygusOp()); + Node cop = dt[j].getConstructor(); + Node op = dt[j].getSygusOp(); Trace("sygus-unif-debug") << "--- Infer strategy from " << cop << " with sygus op " << op << "..." << std::endl; @@ -244,8 +243,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) std::vector sktns; for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++) { - Type t = dt[j][k].getRangeType(); - TypeNode ttn = TypeNode::fromType(t); + TypeNode ttn = dt[j][k].getRangeType(); Node kv = nm->mkSkolem("ut", ttn); sks.push_back(kv); cop_to_sks[cop].push_back(kv); @@ -255,7 +253,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) Node ut = nm->mkNode(APPLY_CONSTRUCTOR, utchildren); std::vector echildren; echildren.push_back(ut); - Node sbvl = Node::fromExpr(dt.getSygusVarList()); + Node sbvl = dt.getSygusVarList(); for (const Node& sbv : sbvl) { echildren.push_back(sbv); @@ -482,10 +480,8 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) for (unsigned k = 0, size = cop_to_carg_list[cop].size(); k < size; k++) { TypeNode ctn = sktns[cop_to_carg_list[cop][k]]; - Trace("sygus-unif-debug") - << " Child type " << k << " : " - << static_cast(ctn.toType()).getDatatype().getName() - << std::endl; + Trace("sygus-unif-debug") << " Child type " << k << " : " + << ctn.getDType().getName() << std::endl; cop_to_child_types[cop].push_back(ctn); } // if there are checks on the consistency of child types wrt strategies, @@ -578,14 +574,10 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) { // it is templated, allocate a fresh variable et = nm->mkSkolem("et", ct); - Trace("sygus-unif-debug") - << "...enumerate " << et << " of type " - << ((DatatypeType)ct.toType()).getDatatype().getName(); + Trace("sygus-unif-debug") << "...enumerate " << et << " of type " + << ct.getDType().getName(); Trace("sygus-unif-debug") << " for arg " << j << " of " - << static_cast(tn.toType()) - .getDatatype() - .getName() - << std::endl; + << tn.getDType().getName() << std::endl; registerStrategyPoint(et, ct, erole_c, true); d_einfo[et].d_template = cop_to_child_templ[cop][j]; d_einfo[et].d_template_arg = cop_to_child_templ_arg[cop][j]; @@ -595,8 +587,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) else { Trace("sygus-unif-debug") - << "...child type enumerate " - << ((DatatypeType)ct.toType()).getDatatype().getName() + << "...child type enumerate " << ct.getDType().getName() << ", node role = " << nrole_c << std::endl; // otherwise use the previous Assert(d_tinfo[ct].d_enum.find(erole_c) @@ -631,9 +622,7 @@ void SygusUnifStrategy::buildStrategyGraph(TypeNode tn, NodeRole nrole) { Trace("sygus-unif") << "Initialized strategy " << strat; Trace("sygus-unif") - << " for " - << static_cast(tn.toType()).getDatatype().getName() - << ", operator " << cop; + << " for " << tn.getDType().getName() << ", operator " << cop; Trace("sygus-unif") << ", #children = " << cons_strat->d_cenum.size() << ", solution template = (lambda ( "; for (const Node& targ : cons_strat->d_sol_templ_args) @@ -707,10 +696,8 @@ void SygusUnifStrategy::staticLearnRedundantOps( std::map::iterator itn = d_einfo.find(e); Assert(itn != d_einfo.end()); // see if there is anything we can eliminate - Trace("sygus-unif") - << "* Search enumerator #" << i << " : type " - << ((DatatypeType)e.getType().toType()).getDatatype().getName() - << " : "; + Trace("sygus-unif") << "* Search enumerator #" << i << " : type " + << e.getType().getDType().getName() << " : "; Trace("sygus-unif") << e << " has " << itn->second.d_enum_slave.size() << " slaves:" << std::endl; for (unsigned j = 0; j < itn->second.d_enum_slave.size(); j++) @@ -734,8 +721,7 @@ void SygusUnifStrategy::staticLearnRedundantOps( for (std::pair >& nce : needs_cons) { Node em = nce.first; - const Datatype& dt = - static_cast(em.getType().toType()).getDatatype(); + const DType& dt = em.getType().getDType(); std::vector lemmas; for (std::pair& nc : nce.second) { @@ -819,21 +805,18 @@ void SygusUnifStrategy::staticLearnRedundantOps( // arguments of ITE are the same BOOL type if (restrictions.d_iteReturnBoolConst) { - const Datatype& dt = - static_cast(etn.toType()).getDatatype(); - Node op = Node::fromExpr(dt[cindex].getSygusOp()); - TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType()); + const DType& dt = etn.getDType(); + Node op = dt[cindex].getSygusOp(); + TypeNode sygus_tn = dt.getSygusType(); if (op.getKind() == kind::BUILTIN - && NodeManager::operatorToKind(op) == ITE - && sygus_tn.isBoolean() - && (TypeNode::fromType(dt[cindex].getArgType(1)) - == TypeNode::fromType(dt[cindex].getArgType(2)))) + && NodeManager::operatorToKind(op) == ITE && sygus_tn.isBoolean() + && (dt[cindex].getArgType(1) == dt[cindex].getArgType(2))) { unsigned ncons = dt.getNumConstructors(), indexT = ncons, indexF = ncons; for (unsigned k = 0; k < ncons; ++k) { - Node op_arg = Node::fromExpr(dt[k].getSygusOp()); + Node op_arg = dt[k].getSygusOp(); if (dt[k].getNumArgs() > 0 || !op_arg.isConst()) { continue; @@ -867,14 +850,14 @@ void SygusUnifStrategy::staticLearnRedundantOps( } } // get the current datatype - const Datatype& dt = static_cast(etn.toType()).getDatatype(); + const DType& dt = etn.getDType(); // do not use recursive Boolean connectives for conditions of ITEs if (nrole == role_ite_condition && restrictions.d_iteCondOnlyAtoms) { - TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType()); + TypeNode sygus_tn = dt.getSygusType(); for (unsigned j = 0, size = dt.getNumConstructors(); j < size; j++) { - Node op = Node::fromExpr(dt[j].getSygusOp()); + Node op = dt[j].getSygusOp(); Trace("sygus-strat-slearn") << "...for ite condition, look at operator : " << op << std::endl; if (op.isConst() && dt[j].getNumArgs() == 0) @@ -894,7 +877,7 @@ void SygusUnifStrategy::staticLearnRedundantOps( bool type_ok = true; for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++) { - TypeNode tn = TypeNode::fromType(dt[j].getArgType(k)); + TypeNode tn = dt[j].getArgType(k); if (tn != etn) { type_ok = false; @@ -991,8 +974,7 @@ void SygusUnifStrategy::debugPrint( indent(c, ind); Trace(c) << e << " :: node role : " << nrole; - Trace(c) << ", type : " - << static_cast(etn.toType()).getDatatype().getName(); + Trace(c) << ", type : " << etn.getDType().getName(); if (ei.isConditional()) { Trace(c) << ", conditional"; diff --git a/src/theory/quantifiers/sygus/synth_conjecture.cpp b/src/theory/quantifiers/sygus/synth_conjecture.cpp index ca4feda32..e30f9771c 100644 --- a/src/theory/quantifiers/sygus/synth_conjecture.cpp +++ b/src/theory/quantifiers/sygus/synth_conjecture.cpp @@ -1050,7 +1050,7 @@ void SynthConjecture::printSynthSolution(std::ostream& out) Node prog = d_embed_quant[0][i]; int status = statuses[i]; TypeNode tn = prog.getType(); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); std::stringstream ss; ss << prog; std::string f(ss.str()); @@ -1113,7 +1113,7 @@ void SynthConjecture::printSynthSolution(std::ostream& out) // pvs stores the variables that will be printed in the argument list // below. std::vector pvs; - Node vl = Node::fromExpr(dt.getSygusVarList()); + Node vl = dt.getSygusVarList(); if (!vl.isNull()) { Assert(vl.getKind() == BOUND_VAR_LIST); @@ -1176,9 +1176,9 @@ bool SynthConjecture::getSynthSolutions( } // convert to lambda TypeNode tn = d_embed_quant[0][i].getType(); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Node fvar = d_quant[0][i]; - Node bvl = Node::fromExpr(dt.getSygusVarList()); + Node bvl = dt.getSygusVarList(); if (!bvl.isNull()) { // since we don't have function subtyping, this assertion should only diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 08fb58e40..bce46fb6b 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -67,9 +67,9 @@ TNode TermDbSygus::getFreeVar( TypeNode tn, int i, bool useSygusType ) { TypeNode vtn = tn; if( useSygusType ){ if( tn.isDatatype() ){ - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); if( !dt.getSygusType().isNull() ){ - vtn = TypeNode::fromType( dt.getSygusType() ); + vtn = dt.getSygusType(); sindex = 1; } } @@ -77,7 +77,7 @@ TNode TermDbSygus::getFreeVar( TypeNode tn, int i, bool useSygusType ) { while( i>=(int)d_fv[sindex][tn].size() ){ std::stringstream ss; if( tn.isDatatype() ){ - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); ss << "fv_" << dt.getName() << "_" << i; }else{ ss << "fv_" << tn << "_" << i; @@ -126,11 +126,8 @@ bool TermDbSygus::hasFreeVar( Node n ) { Node TermDbSygus::getProxyVariable(TypeNode tn, Node c) { Assert(tn.isDatatype()); - Assert(static_cast(tn.toType()).getDatatype().isSygus()); - Assert( - TypeNode::fromType( - static_cast(tn.toType()).getDatatype().getSygusType()) - .isComparableTo(c.getType())); + Assert(tn.getDType().isSygus()); + Assert(tn.getDType().getSygusType().isComparableTo(c.getType())); std::map::iterator it = d_proxy_vars[tn].find(c); if (it == d_proxy_vars[tn].end()) @@ -146,9 +143,9 @@ Node TermDbSygus::getProxyVariable(TypeNode tn, Node c) } else { - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); k = NodeManager::currentNM()->mkNode( - APPLY_CONSTRUCTOR, Node::fromExpr(dt[anyC].getConstructor()), c); + APPLY_CONSTRUCTOR, dt[anyC].getConstructor(), c); } d_proxy_vars[tn][c] = k; return k; @@ -161,7 +158,7 @@ TypeNode TermDbSygus::getSygusTypeForVar( Node v ) { return d_fv_stype[v]; } -Node TermDbSygus::mkGeneric(const Datatype& dt, +Node TermDbSygus::mkGeneric(const DType& dt, unsigned c, std::map& var_count, std::map& pre, @@ -181,7 +178,7 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, a = it->second; Trace("sygus-db-debug") << "From pre: " << a << std::endl; }else{ - TypeNode tna = TypeNode::fromType(dt[c].getArgType(i)); + TypeNode tna = dt[c].getArgType(i); a = getFreeVarInc( tna, var_count, true ); } Trace("sygus-db-debug") @@ -194,7 +191,7 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, return ret; } -Node TermDbSygus::mkGeneric(const Datatype& dt, +Node TermDbSygus::mkGeneric(const DType& dt, int c, std::map& pre, bool doBetaRed) @@ -203,7 +200,7 @@ Node TermDbSygus::mkGeneric(const Datatype& dt, return mkGeneric(dt, c, var_count, pre, doBetaRed); } -Node TermDbSygus::mkGeneric(const Datatype& dt, int c, bool doBetaRed) +Node TermDbSygus::mkGeneric(const DType& dt, int c, bool doBetaRed) { std::map pre; return mkGeneric(dt, c, pre, doBetaRed); @@ -294,7 +291,7 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) } Trace("sygus-db-debug") << "SygusToBuiltin : compute for " << n << ", type = " << tn << std::endl; - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { return n; @@ -306,7 +303,7 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) std::map pre; for (unsigned j = 0, size = n.getNumChildren(); j < size; j++) { - pre[j] = sygusToBuiltin(n[j], TypeNode::fromType(dt[i].getArgType(j))); + pre[j] = sygusToBuiltin(n[j], dt[i].getArgType(j)); Trace("sygus-db-debug") << "sygus to builtin " << n[j] << " is " << pre[j] << std::endl; } @@ -326,7 +323,7 @@ Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn) // map to builtin variable type int fv_num = getVarNum(n); Assert(!dt.getSygusType().isNull()); - TypeNode vtn = TypeNode::fromType(dt.getSygusType()); + TypeNode vtn = dt.getSygusType(); Node ret = getFreeVar(vtn, fv_num); return ret; } @@ -341,7 +338,7 @@ unsigned TermDbSygus::getSygusTermSize( Node n ){ { sum += getSygusTermSize(n[i]); } - const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr()); + const DType& dt = datatypes::utils::datatypeOf(n.getOperator()); int cindex = datatypes::utils::indexOf(n.getOperator()); Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); unsigned weight = dt[cindex].getWeight(); @@ -362,7 +359,7 @@ bool TermDbSygus::registerSygusType(TypeNode tn) { return false; } - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { return false; @@ -407,12 +404,10 @@ void TermDbSygus::registerEnumerator(Node e, TypeNode stn = sf_types[i]; Assert(stn.isDatatype()); SygusTypeInfo& sti = getTypeInfo(stn); - const Datatype& dt = stn.getDatatype(); + const DType& dt = stn.getDType(); int anyC = sti.getAnyConstantConsNum(); for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { - Expr sop = dt[i].getSygusOp(); - Assert(!sop.isNull()); bool isAnyC = static_cast(i) == anyC; if (isAnyC && !useSymbolicCons) { @@ -497,7 +492,7 @@ void TermDbSygus::registerEnumerator(Node e, // sygus stream are to find many solutions to an easy problem, where // the bottleneck often becomes the large number of "exclude the current // solution" clauses. - const Datatype& dt = et.getDatatype(); + const DType& dt = et.getDType(); if (options::sygusStream() || (!eti.hasIte() && !dt.getSygusType().isBoolean())) { @@ -767,8 +762,7 @@ unsigned TermDbSygus::getSelectorWeight(TypeNode tn, Node sel) { d_sel_weight[tn].clear(); itsw = d_sel_weight.find(tn); - Type t = tn.toType(); - const Datatype& dt = static_cast(t).getDatatype(); + const DType& dt = tn.getDType(); Trace("sygus-db") << "Compute selector weights for " << dt.getName() << std::endl; for (unsigned i = 0, size = dt.getNumConstructors(); i < size; i++) @@ -776,7 +770,7 @@ unsigned TermDbSygus::getSelectorWeight(TypeNode tn, Node sel) unsigned cw = dt[i].getWeight(); for (unsigned j = 0, size2 = dt[i].getNumArgs(); j < size2; j++) { - Node csel = Node::fromExpr(dt[i].getSelectorInternal(t, j)); + Node csel = dt[i].getSelectorInternal(tn, j); std::map::iterator its = itsw->second.find(csel); if (its == itsw->second.end() || cw < its->second) { @@ -790,14 +784,15 @@ unsigned TermDbSygus::getSelectorWeight(TypeNode tn, Node sel) return itsw->second[sel]; } -TypeNode TermDbSygus::getArgType(const DatatypeConstructor& c, unsigned i) const +TypeNode TermDbSygus::getArgType(const DTypeConstructor& c, unsigned i) const { Assert(i < c.getNumArgs()); - return TypeNode::fromType( - static_cast(c[i].getType()).getRangeType()); + return c.getArgType(i); } -bool TermDbSygus::isTypeMatch( const DatatypeConstructor& c1, const DatatypeConstructor& c2 ) { +bool TermDbSygus::isTypeMatch(const DTypeConstructor& c1, + const DTypeConstructor& c2) +{ if( c1.getNumArgs()!=c2.getNumArgs() ){ return false; }else{ @@ -818,10 +813,10 @@ bool TermDbSygus::isSymbolicConsApp(Node n) const } TypeNode tn = n.getType(); Assert(tn.isDatatype()); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); unsigned cindex = datatypes::utils::indexOf(n.getOperator()); - Node sygusOp = Node::fromExpr(dt[cindex].getSygusOp()); + Node sygusOp = dt[cindex].getSygusOp(); // it is symbolic if it represents "any constant" return sygusOp.getAttribute(SygusAnyConstAttribute()); } @@ -834,12 +829,12 @@ bool TermDbSygus::canConstructKind(TypeNode tn, Assert(isRegistered(tn)); SygusTypeInfo& ti = getTypeInfo(tn); int c = ti.getKindConsNum(k); - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (c != -1) { for (unsigned i = 0, nargs = dt[c].getNumArgs(); i < nargs; i++) { - argts.push_back(TypeNode::fromType(dt[c].getArgType(i))); + argts.push_back(dt[c].getArgType(i)); } return true; } diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 76b5039f6..6d328ddca 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -19,6 +19,7 @@ #include +#include "expr/dtype.h" #include "theory/evaluator.h" #include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/fun_def_evaluator.h" @@ -229,18 +230,18 @@ class TermDbSygus { * If doBetaRed is true, then lambda operators are eagerly eliminated via * beta reduction. */ - Node mkGeneric(const Datatype& dt, + Node mkGeneric(const DType& dt, unsigned c, std::map& var_count, std::map& pre, bool doBetaRed = true); /** same as above, but with empty var_count */ - Node mkGeneric(const Datatype& dt, + Node mkGeneric(const DType& dt, int c, std::map& pre, bool doBetaRed = true); /** same as above, but with empty pre */ - Node mkGeneric(const Datatype& dt, int c, bool doBetaRed = true); + Node mkGeneric(const DType& dt, int c, bool doBetaRed = true); /** makes a symbolic term concrete * * Given a sygus datatype term n of type tn with holes (symbolic constructor @@ -413,9 +414,9 @@ class TermDbSygus { /** get the weight of the selector, where tn is the domain of sel */ unsigned getSelectorWeight(TypeNode tn, Node sel); /** get arg type */ - TypeNode getArgType(const DatatypeConstructor& c, unsigned i) const; + TypeNode getArgType(const DTypeConstructor& c, unsigned i) const; /** Do constructors c1 and c2 have the same type? */ - bool isTypeMatch( const DatatypeConstructor& c1, const DatatypeConstructor& c2 ); + bool isTypeMatch(const DTypeConstructor& c1, const DTypeConstructor& c2); /** return whether n is an application of a symbolic constructor */ bool isSymbolicConsApp(Node n) const; /** can construct kind diff --git a/src/theory/quantifiers/sygus/type_info.cpp b/src/theory/quantifiers/sygus/type_info.cpp index 71ccd60c9..a17a60927 100644 --- a/src/theory/quantifiers/sygus/type_info.cpp +++ b/src/theory/quantifiers/sygus/type_info.cpp @@ -15,6 +15,7 @@ #include "theory/quantifiers/sygus/type_info.h" #include "base/check.h" +#include "expr/dtype.h" #include "expr/sygus_datatype.h" #include "theory/datatypes/theory_datatypes_utils.h" #include "theory/quantifiers/sygus/term_database_sygus.h" @@ -37,14 +38,14 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) { d_this = tn; Assert(tn.isDatatype()); - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); Assert(dt.isSygus()); Trace("sygus-db") << "Register type " << dt.getName() << "..." << std::endl; - TypeNode btn = TypeNode::fromType(dt.getSygusType()); + TypeNode btn = dt.getSygusType(); d_btype = btn; Assert(!d_btype.isNull()); // get the sygus variable list - Node var_list = Node::fromExpr(dt.getSygusVarList()); + Node var_list = dt.getSygusVarList(); if (!var_list.isNull()) { for (unsigned j = 0; j < var_list.getNumChildren(); j++) @@ -77,7 +78,7 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) { for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode ctn = TypeNode::fromType(dt[i].getArgType(j)); + TypeNode ctn = dt[i].getArgType(j); Trace("sygus-db") << " register subfield type " << ctn << std::endl; if (tds->registerSygusType(ctn)) { @@ -93,13 +94,12 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) // iterate over constructors for (unsigned i = 0; i < dt.getNumConstructors(); i++) { - Expr sop = dt[i].getSygusOp(); + Node sop = dt[i].getSygusOp(); Assert(!sop.isNull()); - Node n = Node::fromExpr(sop); Trace("sygus-db") << " Operator #" << i << " : " << sop; if (sop.getKind() == kind::BUILTIN) { - Kind sk = NodeManager::operatorToKind(n); + Kind sk = NodeManager::operatorToKind(sop); Trace("sygus-db") << ", kind = " << sk; d_kinds[sk] = i; d_arg_kind[i] = sk; @@ -112,8 +112,8 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) else if (sop.isConst() && dt[i].getNumArgs() == 0) { Trace("sygus-db") << ", constant"; - d_consts[n] = i; - d_arg_const[i] = n; + d_consts[sop] = i; + d_arg_const[i] = sop; } else if (sop.getKind() == LAMBDA) { @@ -121,9 +121,9 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) Assert(sop[0].getNumChildren() == dt[i].getNumArgs()); for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode ct = TypeNode::fromType(dt[i].getArgType(j)); + TypeNode ct = dt[i].getArgType(j); TypeNode cbt = tds->sygusToBuiltinType(ct); - TypeNode lat = TypeNode::fromType(sop[0][j].getType()); + TypeNode lat = sop[0][j].getType(); AlwaysAssert(cbt.isSubtypeOf(lat)) << "In sygus datatype " << dt.getName() << ", argument to a lambda constructor is not " << lat << std::endl; @@ -135,13 +135,13 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) } } // symbolic constructors - if (n.getAttribute(SygusAnyConstAttribute())) + if (sop.getAttribute(SygusAnyConstAttribute())) { d_sym_cons_any_constant = i; d_has_subterm_sym_cons = true; } - d_ops[n] = i; - d_arg_ops[i] = n; + d_ops[sop] = i; + d_arg_ops[i] = sop; Trace("sygus-db") << std::endl; // We must properly catch type errors in sygus grammars for arguments of // builtin operators. The challenge is that we easily ask for expected @@ -170,7 +170,7 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) csize = 1; for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode ct = TypeNode::fromType(dt[i].getArgType(j)); + TypeNode ct = dt[i].getArgType(j); if (ct == tn) { csize += d_min_term_size; @@ -182,7 +182,7 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn) } else { - Assert(!ct.isDatatype() || !ct.getDatatype().isSygus()); + Assert(!ct.isDatatype() || !ct.getDType().isSygus()); } } } @@ -219,12 +219,11 @@ void SygusTypeInfo::initializeVarSubclasses() std::vector rm_indices; TypeNode stn = sf_types[i]; Assert(stn.isDatatype()); - const Datatype& dt = stn.getDatatype(); + const DType& dt = stn.getDType(); for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++) { - Expr sop = dt[j].getSygusOp(); - Assert(!sop.isNull()); - Node sopn = Node::fromExpr(sop); + Node sopn = dt[j].getSygusOp(); + Assert(!sopn.isNull()); if (type_occurs.find(sopn) != type_occurs.end()) { // if it is a variable, store that it occurs in stn @@ -272,7 +271,7 @@ void SygusTypeInfo::computeMinTypeDepthInternal(TypeNode tn, // do not recurse to non-datatype types return; } - const Datatype& dt = tn.getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { // do not recurse to non-sygus datatype types @@ -284,7 +283,7 @@ void SygusTypeInfo::computeMinTypeDepthInternal(TypeNode tn, { for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++) { - TypeNode at = TypeNode::fromType(dt[i].getArgType(j)); + TypeNode at = dt[i].getArgType(j); computeMinTypeDepthInternal(at, type_depth + 1); } } diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index 834ca1975..10af0d703 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -14,6 +14,7 @@ #include "theory/quantifiers/sygus_sampler.h" +#include "expr/dtype.h" #include "expr/node_algorithm.h" #include "options/base_options.h" #include "options/quantifiers_options.h" @@ -92,7 +93,7 @@ void SygusSampler::initializeSygus(TermDbSygus* tds, d_is_valid = true; d_ftn = f.getType(); Assert(d_ftn.isDatatype()); - const Datatype& dt = static_cast(d_ftn.toType()).getDatatype(); + const DType& dt = d_ftn.getDType(); Assert(dt.isSygus()); Trace("sygus-sample") << "Register sampler for " << f << std::endl; @@ -105,7 +106,7 @@ void SygusSampler::initializeSygus(TermDbSygus* tds, d_rvalue_null_cindices.clear(); d_var_sygus_types.clear(); // get the sygus variable list - Node var_list = Node::fromExpr(dt.getSygusVarList()); + Node var_list = dt.getSygusVarList(); if (!var_list.isNull()) { for (const Node& sv : var_list) @@ -659,7 +660,7 @@ Node SygusSampler::getSygusRandomValue(TypeNode tn, { return getRandomValue(tn); } - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { return getRandomValue(tn); @@ -685,7 +686,7 @@ Node SygusSampler::getSygusRandomValue(TypeNode tn, << "Recurse constructor index #" << index << std::endl; unsigned cindex = cindices[index]; Assert(cindex < dt.getNumConstructors()); - const DatatypeConstructor& dtc = dt[cindex]; + const DTypeConstructor& dtc = dt[cindex]; // more likely to terminate in recursive calls double rchance_new = rchance + (1.0 - rchance) * rinc; std::map pre; @@ -718,7 +719,7 @@ Node SygusSampler::getSygusRandomValue(TypeNode tn, } Trace("sygus-sample-grammar") << "...resort to random value" << std::endl; // if we did not generate based on the grammar, pick a random value - return getRandomValue(TypeNode::fromType(dt.getSygusType())); + return getRandomValue(dt.getSygusType()); } // recursion depth bounded by number of types in grammar (small) @@ -731,15 +732,15 @@ void SygusSampler::registerSygusType(TypeNode tn) { return; } - const Datatype& dt = static_cast(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { return; } for (unsigned i = 0, ncons = dt.getNumConstructors(); i < ncons; i++) { - const DatatypeConstructor& dtc = dt[i]; - Node sop = Node::fromExpr(dtc.getSygusOp()); + const DTypeConstructor& dtc = dt[i]; + Node sop = dtc.getSygusOp(); bool isVar = std::find(d_vars.begin(), d_vars.end(), sop) != d_vars.end(); if (isVar) { diff --git a/src/theory/sets/rels_utils.h b/src/theory/sets/rels_utils.h index 8ce314c94..79757d311 100644 --- a/src/theory/sets/rels_utils.h +++ b/src/theory/sets/rels_utils.h @@ -17,6 +17,9 @@ #ifndef SRC_THEORY_SETS_RELS_UTILS_H_ #define SRC_THEORY_SETS_RELS_UTILS_H_ +#include "expr/dtype.h" +#include "expr/node.h" + namespace CVC4 { namespace theory { namespace sets { @@ -67,8 +70,9 @@ public: return tuple[n_th]; } TypeNode tn = tuple.getType(); - const Datatype& dt = tn.getDatatype(); - return NodeManager::currentNM()->mkNode(kind::APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal( tn.toType(), n_th ), tuple); + const DType& dt = tn.getDType(); + return NodeManager::currentNM()->mkNode( + kind::APPLY_SELECTOR_TOTAL, dt[0].getSelectorInternal(tn, n_th), tuple); } static Node reverseTuple( Node tuple ) { @@ -77,16 +81,17 @@ public: std::vector tuple_types = tuple.getType().getTupleTypes(); std::reverse( tuple_types.begin(), tuple_types.end() ); TypeNode tn = NodeManager::currentNM()->mkTupleType( tuple_types ); - const Datatype& dt = tn.getDatatype(); - elements.push_back( Node::fromExpr(dt[0].getConstructor() ) ); + const DType& dt = tn.getDType(); + elements.push_back(dt[0].getConstructor()); for(int i = tuple_types.size() - 1; i >= 0; --i) { elements.push_back( nthElementOfTuple(tuple, i) ); } return NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, elements ); } static Node constructPair(Node rel, Node a, Node b) { - const Datatype& dt = rel.getType().getSetElementType().getDatatype(); - return NodeManager::currentNM()->mkNode(kind::APPLY_CONSTRUCTOR, Node::fromExpr(dt[0].getConstructor()), a, b); + const DType& dt = rel.getType().getSetElementType().getDType(); + return NodeManager::currentNM()->mkNode( + kind::APPLY_CONSTRUCTOR, dt[0].getConstructor(), a, b); } }; diff --git a/src/theory/sets/theory_sets_rels.cpp b/src/theory/sets/theory_sets_rels.cpp index e8724aa8b..2f0997982 100644 --- a/src/theory/sets/theory_sets_rels.cpp +++ b/src/theory/sets/theory_sets_rels.cpp @@ -299,12 +299,12 @@ void TheorySetsRels::check(Theory::Effort level) } hasChecked.insert( fst_mem_rep ); - const Datatype& dt = - join_image_term.getType().getSetElementType().getDatatype(); - Node new_membership = NodeManager::currentNM()->mkNode(kind::MEMBER, - NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, - Node::fromExpr(dt[0].getConstructor()), fst_mem_rep ), - join_image_term); + const DType& dt = + join_image_term.getType().getSetElementType().getDType(); + Node new_membership = nm->mkNode( + MEMBER, + nm->mkNode(APPLY_CONSTRUCTOR, dt[0].getConstructor(), fst_mem_rep), + join_image_term); if (d_state.isEntailed(new_membership, true)) { ++mem_rep_it; @@ -429,9 +429,11 @@ void TheorySetsRels::check(Theory::Effort level) Node reason = exp; Node fst_mem = RelsUtils::nthElementOfTuple( exp[0], 0 ); Node snd_mem = RelsUtils::nthElementOfTuple( exp[0], 1 ); - const Datatype& dt = - iden_term[0].getType().getSetElementType().getDatatype(); - Node fact = NodeManager::currentNM()->mkNode( kind::MEMBER, NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, Node::fromExpr(dt[0].getConstructor()), fst_mem ), iden_term[0] ); + const DType& dt = iden_term[0].getType().getSetElementType().getDType(); + Node fact = nm->mkNode( + MEMBER, + nm->mkNode(APPLY_CONSTRUCTOR, dt[0].getConstructor(), fst_mem), + iden_term[0]); if( exp[1] != iden_term ) { reason = NodeManager::currentNM()->mkNode( kind::AND, reason, NodeManager::currentNM()->mkNode( kind::EQUAL, exp[1], iden_term ) ); @@ -767,18 +769,18 @@ void TheorySetsRels::check(Theory::Effort level) Node mem = exp[0]; std::vector r1_element; std::vector r2_element; - const Datatype& dt1 = pt_rel[0].getType().getSetElementType().getDatatype(); + const DType& dt1 = pt_rel[0].getType().getSetElementType().getDType(); unsigned int s1_len = pt_rel[0].getType().getSetElementType().getTupleLength(); unsigned int tup_len = pt_rel.getType().getSetElementType().getTupleLength(); - r1_element.push_back(Node::fromExpr(dt1[0].getConstructor())); + r1_element.push_back(dt1[0].getConstructor()); unsigned int i = 0; for(; i < s1_len; ++i) { r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); } - const Datatype& dt2 = pt_rel[1].getType().getSetElementType().getDatatype(); - r2_element.push_back(Node::fromExpr(dt2[0].getConstructor())); + const DType& dt2 = pt_rel[1].getType().getSetElementType().getDType(); + r2_element.push_back(dt2[0].getConstructor()); for(; i < tup_len; ++i) { r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); } @@ -825,20 +827,18 @@ void TheorySetsRels::check(Theory::Effort level) TypeNode shared_type = r2_rep.getType().getSetElementType().getTupleTypes()[0]; Node shared_x = d_state.getSkolemCache().mkTypedSkolemCached( shared_type, mem, join_rel, SkolemCache::SK_JOIN, "srj"); - const Datatype& dt1 = - join_rel[0].getType().getSetElementType().getDatatype(); + const DType& dt1 = join_rel[0].getType().getSetElementType().getDType(); unsigned int s1_len = join_rel[0].getType().getSetElementType().getTupleLength(); unsigned int tup_len = join_rel.getType().getSetElementType().getTupleLength(); unsigned int i = 0; - r1_element.push_back(Node::fromExpr(dt1[0].getConstructor())); + r1_element.push_back(dt1[0].getConstructor()); for(; i < s1_len-1; ++i) { r1_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); } r1_element.push_back(shared_x); - const Datatype& dt2 = - join_rel[1].getType().getSetElementType().getDatatype(); - r2_element.push_back(Node::fromExpr(dt2[0].getConstructor())); + const DType& dt2 = join_rel[1].getType().getSetElementType().getDType(); + r2_element.push_back(dt2[0].getConstructor()); r2_element.push_back(shared_x); for(; i < tup_len; ++i) { r2_element.push_back(RelsUtils::nthElementOfTuple(mem, i)); @@ -1041,7 +1041,7 @@ void TheorySetsRels::check(Theory::Effort level) TypeNode tn = rel.getType().getSetElementType(); Node r1_rmost = RelsUtils::nthElementOfTuple( r1_rep_exps[i][0], r1_tuple_len-1 ); Node r2_lmost = RelsUtils::nthElementOfTuple( r2_rep_exps[j][0], 0 ); - tuple_elements.push_back( Node::fromExpr(tn.getDatatype()[0].getConstructor()) ); + tuple_elements.push_back(tn.getDType()[0].getConstructor()); if( (areEqual(r1_rmost, r2_lmost) && rel.getKind() == kind::JOIN) || rel.getKind() == kind::PRODUCT ) { @@ -1226,7 +1226,7 @@ void TheorySetsRels::check(Theory::Effort level) if(d_symbolic_tuples.find(n) == d_symbolic_tuples.end()) { Trace("rels-debug") << "[Theory::Rels] Reduce tuple var: " << n[0] << " to a concrete one " << " node = " << n << std::endl; std::vector tuple_elements; - tuple_elements.push_back(Node::fromExpr((n[0].getType().getDatatype())[0].getConstructor())); + tuple_elements.push_back((n[0].getType().getDType())[0].getConstructor()); for(unsigned int i = 0; i < n[0].getType().getTupleLength(); i++) { Node element = RelsUtils::nthElementOfTuple(n[0], i); makeSharedTerm(element); diff --git a/src/theory/sets/theory_sets_rewriter.cpp b/src/theory/sets/theory_sets_rewriter.cpp index a7c506582..d8f5f8c4f 100644 --- a/src/theory/sets/theory_sets_rewriter.cpp +++ b/src/theory/sets/theory_sets_rewriter.cpp @@ -21,6 +21,8 @@ #include "theory/sets/normal_form.h" #include "theory/sets/rels_utils.h" +using namespace CVC4::kind; + namespace CVC4 { namespace theory { namespace sets { @@ -281,7 +283,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { while(left_it != left.end()) { Trace("rels-debug") << "Sets::postRewrite processing left_it = " << *left_it << std::endl; std::vector left_tuple; - left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + left_tuple.push_back(tn.getDType()[0].getConstructor()); for(int i = 0; i < left_len; i++) { left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i)); } @@ -324,7 +326,7 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { TypeNode tn = node.getType().getSetElementType(); while(left_it != left.end()) { std::vector left_tuple; - left_tuple.push_back(Node::fromExpr(tn.getDatatype()[0].getConstructor())); + left_tuple.push_back(tn.getDType()[0].getConstructor()); for(int i = 0; i < left_len - 1; i++) { left_tuple.push_back(RelsUtils::nthElementOfTuple(*left_it,i)); } @@ -431,8 +433,9 @@ RewriteResponse TheorySetsRewriter::postRewrite(TNode node) { ++rel_mems_it_snd; } if( existing_mems.size() >= min_card ) { - const Datatype& dt = node.getType().getSetElementType().getDatatype(); - join_img_mems.insert(NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, Node::fromExpr(dt[0].getConstructor()), fst_mem )); + const DType& dt = node.getType().getSetElementType().getDType(); + join_img_mems.insert( + nm->mkNode(APPLY_CONSTRUCTOR, dt[0].getConstructor(), fst_mem)); } ++rel_mems_it; } diff --git a/src/theory/theory_model_builder.cpp b/src/theory/theory_model_builder.cpp index 47355aa81..6df412ae3 100644 --- a/src/theory/theory_model_builder.cpp +++ b/src/theory/theory_model_builder.cpp @@ -13,6 +13,7 @@ **/ #include "theory/theory_model_builder.h" +#include "expr/dtype.h" #include "options/quantifiers_options.h" #include "options/smt_options.h" #include "options/uf_options.h" @@ -198,7 +199,7 @@ bool TheoryEngineModelBuilder::involvesUSort(TypeNode tn) } else if (tn.isDatatype()) { - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); return dt.involvesUninterpretedType(); } else @@ -264,12 +265,12 @@ void TheoryEngineModelBuilder::addToTypeList( } else if (tn.isDatatype()) { - const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + const DType& dt = tn.getDType(); for (unsigned i = 0; i < dt.getNumConstructors(); i++) { for (unsigned j = 0; j < dt[i].getNumArgs(); j++) { - TypeNode ctn = TypeNode::fromType(dt[i][j].getRangeType()); + TypeNode ctn = dt[i][j].getRangeType(); addToTypeList(ctn, type_list, visiting); } } @@ -627,10 +628,9 @@ bool TheoryEngineModelBuilder::buildModel(Model* m) bool isCorecursive = false; if (t.isDatatype()) { - const Datatype& dt = ((DatatypeType)(t).toType()).getDatatype(); - isCorecursive = - dt.isCodatatype() && (!dt.isFinite(t.toType()) - || dt.isRecursiveSingleton(t.toType())); + const DType& dt = t.getDType(); + isCorecursive = dt.isCodatatype() + && (!dt.isFinite(t) || dt.isRecursiveSingleton(t)); } #ifdef CVC4_ASSERTIONS bool isUSortFiniteRestricted = false;