From: Andrew Reynolds Date: Fri, 31 Jan 2020 16:43:36 +0000 (-0600) Subject: Update sygus grammar normalization to use node-level datatype. (#3567) X-Git-Tag: cvc5-1.0.0~3700 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=087ff3ef026440480eb7f72c75f0710b10192623;p=cvc5.git Update sygus grammar normalization to use node-level datatype. (#3567) --- diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index c7c1d820f..b2e7d2681 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -162,13 +162,15 @@ Node SygusGrammarNorm::TypeObject::eliminatePartialOperators(Node n) return visited[n]; } -void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm, - const DatatypeConstructor& cons) +void SygusGrammarNorm::TypeObject::addConsInfo( + SygusGrammarNorm* sygus_norm, + const DTypeConstructor& cons, + std::shared_ptr spc) { Trace("sygus-grammar-normalize") << "...for " << cons.getName() << "\n"; /* Recover the sygus operator to not lose reference to the original * operator (NOT, ITE, etc) */ - Node sygus_op = Node::fromExpr(cons.getSygusOp()); + Node sygus_op = cons.getSygusOp(); Trace("sygus-grammar-normalize-debug") << ".....operator is " << sygus_op << std::endl; Node exp_sop_n = sygus_op; @@ -208,11 +210,12 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm, } std::vector consTypes; - for (const DatatypeConstructorArg& arg : cons) + const std::vector >& args = cons.getArgs(); + for (const std::shared_ptr& arg : args) { // Collect unresolved type nodes corresponding to the typenode of the // arguments. - TypeNode atype = TypeNode::fromType(arg.getRangeType()); + TypeNode atype = arg->getRangeType(); // normalize it recursively atype = sygus_norm->normalizeSygusRec(atype); consTypes.push_back(atype); @@ -220,19 +223,16 @@ void SygusGrammarNorm::TypeObject::addConsInfo(SygusGrammarNorm* sygus_norm, Trace("sygus-type-cons-defs") << "\tOriginal op: " << cons.getSygusOp() << "\n\tExpanded one: " << exp_sop_n << "\n\n"; - d_sdt.addConstructor(exp_sop_n, - cons.getName(), - consTypes, - cons.getSygusPrintCallback(), - cons.getWeight()); + d_sdt.addConstructor( + exp_sop_n, cons.getName(), consTypes, spc, cons.getWeight()); } void SygusGrammarNorm::TypeObject::initializeDatatype( - SygusGrammarNorm* sygus_norm, const Datatype& dt) + SygusGrammarNorm* sygus_norm, const DType& dt) { /* Use the sygus type to not lose reference to the original types (Bool, * Int, etc) */ - TypeNode sygusType = TypeNode::fromType(dt.getSygusType()); + TypeNode sygusType = dt.getSygusType(); d_sdt.initializeDatatype(sygusType, sygus_norm->d_sygus_vars.toExpr(), dt.getSygusAllowConst(), @@ -247,7 +247,7 @@ void SygusGrammarNorm::TypeObject::initializeDatatype( void SygusGrammarNorm::TransfDrop::buildType(SygusGrammarNorm* sygus_norm, TypeObject& to, - const Datatype& dt, + const DType& dt, std::vector& op_pos) { std::vector difference; @@ -287,7 +287,7 @@ bool SygusGrammarNorm::TransfChain::isId(TypeNode tn, Node op, Node n) void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm, TypeObject& to, - const Datatype& dt, + const DType& dt, std::vector& op_pos) { NodeManager* nm = NodeManager::currentNM(); @@ -324,8 +324,7 @@ void SygusGrammarNorm::TransfChain::buildType(SygusGrammarNorm* sygus_norm, Trace("sygus-grammar-normalize-chain") << "\n"; } /* Build identity operator and empty callback */ - Node iden_op = - SygusGrammarNorm::getIdOp(TypeNode::fromType(dt.getSygusType())); + Node iden_op = SygusGrammarNorm::getIdOp(dt.getSygusType()); /* If all operators are claimed, create a monomial */ if (nb_op_pos == d_elem_pos.size() + 1) { @@ -398,10 +397,10 @@ std::map SygusGrammarNorm::d_tn_to_id = {}; * returns true if collected anything */ std::unique_ptr SygusGrammarNorm::inferTransf( - TypeNode tn, const Datatype& dt, const std::vector& op_pos) + TypeNode tn, const DType& dt, const std::vector& op_pos) { NodeManager* nm = NodeManager::currentNM(); - TypeNode sygus_tn = TypeNode::fromType(dt.getSygusType()); + TypeNode sygus_tn = dt.getSygusType(); Trace("sygus-gnorm") << "Infer transf for " << dt.getName() << "..." << std::endl; Trace("sygus-gnorm") << " #cons = " << op_pos.size() << " / " @@ -436,21 +435,20 @@ std::unique_ptr SygusGrammarNorm::inferTransf( for (unsigned i = 0, size = op_pos.size(); i < size; ++i) { Assert(op_pos[i] < dt.getNumConstructors()); - Expr sop = dt[op_pos[i]].getSygusOp(); + Node sop = dt[op_pos[i]].getSygusOp(); /* Collects a chainable operator such as PLUS */ - if (sop.getKind() == BUILTIN - && TransfChain::isChainable(sygus_tn, Node::fromExpr(sop))) + if (sop.getKind() == BUILTIN && TransfChain::isChainable(sygus_tn, sop)) { - Assert(nm->operatorToKind(Node::fromExpr(sop)) == PLUS); + Assert(nm->operatorToKind(sop) == PLUS); /* TODO #1304: be robust for this case */ /* For now only transforms applications whose arguments have the same type * as the root */ bool same_type_plus = true; - for (const DatatypeConstructorArg& arg : dt[op_pos[i]]) + const std::vector >& args = + dt[op_pos[i]].getArgs(); + for (const std::shared_ptr& arg : args) { - if (TypeNode::fromType( - static_cast(arg.getType()).getRangeType()) - != tn) + if (arg->getRangeType() != tn) { same_type_plus = false; break; @@ -472,7 +470,7 @@ std::unique_ptr SygusGrammarNorm::inferTransf( } /* TODO #1304: check this for each operator */ /* Collects elements that are not the identity (e.g. 0 is the id of PLUS) */ - if (!TransfChain::isId(sygus_tn, nm->operatorOf(PLUS), Node::fromExpr(sop))) + if (!TransfChain::isId(sygus_tn, nm->operatorOf(PLUS), sop)) { Trace("sygus-grammar-normalize-infer") << "\tCollecting for NON_ID_ELEMS the sop " << sop @@ -492,7 +490,7 @@ std::unique_ptr SygusGrammarNorm::inferTransf( } TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn, - const Datatype& dt, + const DType& dt, std::vector& op_pos) { Assert(tn.isDatatype()); @@ -541,7 +539,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn, if (dt.getSygusAllowConst()) { - TypeNode sygus_type = TypeNode::fromType(dt.getSygusType()); + TypeNode sygus_type = dt.getSygusType(); // must be handled by counterexample-guided instantiation // don't do it for Boolean (not worth the trouble, since it has only // minimal gain (1 any constant vs 2 constructors for true/false), and @@ -551,7 +549,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn, && !sygus_type.isBoolean()) { Trace("sygus-grammar-normalize") << "...add any constant constructor.\n"; - TypeNode dtn = TypeNode::fromType(dt.getSygusType()); + TypeNode dtn = dt.getSygusType(); // we add this constructor first since we use left associative chains // and our symmetry breaking should group any constants together // beneath the same application @@ -570,11 +568,15 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn, transformation->buildType(this, to, dt, op_pos); } - /* Remaining operators are rebuilt as they are */ + // Remaining operators are rebuilt as they are. + // Notice that we must extract the Datatype here to get the (Expr-layer) + // sygus print callback. + const Datatype& dtt = DatatypeType(tn.toType()).getDatatype(); for (unsigned i = 0, size = op_pos.size(); i < size; ++i) { - Assert(op_pos[i] < dt.getNumConstructors()); - to.addConsInfo(this, dt[op_pos[i]]); + unsigned oi = op_pos[i]; + Assert(oi < dt.getNumConstructors()); + to.addConsInfo(this, dt[oi], dtt[oi].getSygusPrintCallback()); } /* Build normalize datatype */ if (Trace.isOn("sygus-grammar-normalize")) @@ -599,7 +601,7 @@ TypeNode SygusGrammarNorm::normalizeSygusRec(TypeNode tn) return tn; } /* Collect all operators for normalization */ - const Datatype& dt = DatatypeType(tn.toType()).getDatatype(); + const DType& dt = tn.getDType(); if (!dt.isSygus()) { // datatype but not sygus datatype case diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.h b/src/theory/quantifiers/sygus/sygus_grammar_norm.h index f9c53c4ad..360762b38 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.h @@ -198,7 +198,8 @@ class SygusGrammarNorm * The types of the arguments of "cons" are recursively normalized */ void addConsInfo(SygusGrammarNorm* sygus_norm, - const DatatypeConstructor& cons); + const DTypeConstructor& cons, + std::shared_ptr spc); /** * Returns the total version of Kind k if it is a partial operator, or * otherwise k itself. @@ -219,7 +220,7 @@ class SygusGrammarNorm * The initialized datatype and its unresolved type are saved in the global * accumulators of "sygus_norm" */ - void initializeDatatype(SygusGrammarNorm* sygus_norm, const Datatype& dt); + void initializeDatatype(SygusGrammarNorm* sygus_norm, const DType& dt); //---------- information stored from original type node @@ -253,7 +254,7 @@ class SygusGrammarNorm */ virtual void buildType(SygusGrammarNorm* sygus_norm, TypeObject& to, - const Datatype& dt, + const DType& dt, std::vector& op_pos) = 0; }; /* class Transf */ @@ -271,7 +272,7 @@ class SygusGrammarNorm /** build type */ void buildType(SygusGrammarNorm* sygus_norm, TypeObject& to, - const Datatype& dt, + const DType& dt, std::vector& op_pos) override; private: @@ -329,7 +330,7 @@ class SygusGrammarNorm */ void buildType(SygusGrammarNorm* sygus_norm, TypeObject& to, - const Datatype& dt, + const DType& dt, std::vector& op_pos) override; /** Whether operator is chainable for the type (e.g. PLUS for Int) @@ -421,7 +422,7 @@ class SygusGrammarNorm * recursion depth is limited by the height of the types, which is small */ TypeNode normalizeSygusRec(TypeNode tn, - const Datatype& dt, + const DType& dt, std::vector& op_pos); /** wrapper for the above function @@ -436,7 +437,7 @@ class SygusGrammarNorm * TODO: #1304: Infer more complex transformations */ std::unique_ptr inferTransf(TypeNode tn, - const Datatype& dt, + const DType& dt, const std::vector& op_pos); }; /* class SygusGrammarNorm */