From bcd447593e30dd08c6dfc2e162505b9e815fd29b Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 10 Dec 2021 15:47:18 -0600 Subject: [PATCH] Refactor and fixes related to getSpecializedConstructorTerm (#7774) Fixes cvc5/cvc5-projects#381. --- src/api/cpp/cvc5.cpp | 8 +------ src/expr/dtype_cons.cpp | 18 ++++++++++---- src/expr/dtype_cons.h | 8 ++++++- src/theory/datatypes/datatypes_rewriter.cpp | 14 +++++++---- src/theory/datatypes/theory_datatypes.cpp | 2 +- .../datatypes/theory_datatypes_utils.cpp | 7 +----- src/theory/datatypes/type_enumerator.cpp | 8 ++----- .../quantifiers/cegqi/ceg_instantiator.cpp | 2 +- src/theory/quantifiers/quant_split.cpp | 2 +- .../quantifiers/quantifiers_rewriter.cpp | 6 ++--- src/theory/quantifiers/skolemize.cpp | 4 ++-- .../quantifiers/sygus/sygus_grammar_cons.cpp | 7 +++--- test/unit/api/cpp/solver_black.cpp | 24 +++++++++++++++++++ 13 files changed, 69 insertions(+), 41 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index e062e60ed..aa3e7fa3f 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -3815,13 +3815,7 @@ Term DatatypeConstructor::getSpecializedConstructorTerm( << "Cannot get specialized constructor type for non-datatype type " << retSort; //////// all checks before this line - - NodeManager* nm = d_solver->getNodeManager(); - Node ret = - nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType( - d_ctor->getSpecializedConstructorType(*retSort.d_type))), - d_ctor->getConstructor()); + Node ret = d_ctor->getInstantiatedConstructor(*retSort.d_type); (void)ret.getType(true); /* kick off type checking */ // apply type ascription to the operator Term sctor = api::Term(d_solver, ret); diff --git a/src/expr/dtype_cons.cpp b/src/expr/dtype_cons.cpp index 6ba3970c9..a054dffb8 100644 --- a/src/expr/dtype_cons.cpp +++ b/src/expr/dtype_cons.cpp @@ -83,6 +83,16 @@ Node DTypeConstructor::getConstructor() const return d_constructor; } +Node DTypeConstructor::getInstantiatedConstructor(TypeNode returnType) const +{ + Assert(isResolved()); + NodeManager* nm = NodeManager::currentNM(); + return nm->mkNode( + kind::APPLY_TYPE_ASCRIPTION, + nm->mkConst(AscriptionType(getInstantiatedConstructorType(returnType))), + d_constructor); +} + Node DTypeConstructor::getTester() const { Assert(isResolved()); @@ -116,12 +126,12 @@ unsigned DTypeConstructor::getWeight() const size_t DTypeConstructor::getNumArgs() const { return d_args.size(); } -TypeNode DTypeConstructor::getSpecializedConstructorType( +TypeNode DTypeConstructor::getInstantiatedConstructorType( TypeNode returnType) const { Assert(isResolved()); Assert(returnType.isDatatype()) - << "DTypeConstructor::getSpecializedConstructorType: expected datatype, " + << "DTypeConstructor::getInstantiatedConstructorType: expected datatype, " "got " << returnType; TypeNode ctn = d_constructor.getType(); @@ -439,7 +449,7 @@ Node DTypeConstructor::computeGroundTerm(TypeNode t, << ", ascribe to " << t << std::endl; groundTerms[0] = nm->mkNode( APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(getSpecializedConstructorType(t))), + nm->mkConst(AscriptionType(getInstantiatedConstructorType(t))), groundTerms[0]); groundTerm = nm->mkNode(APPLY_CONSTRUCTOR, groundTerms); } @@ -456,7 +466,7 @@ void DTypeConstructor::computeSharedSelectors(TypeNode domainType) const TypeNode ctype; if (domainType.isParametricDatatype()) { - ctype = getSpecializedConstructorType(domainType); + ctype = getInstantiatedConstructorType(domainType); } else { diff --git a/src/expr/dtype_cons.h b/src/expr/dtype_cons.h index a6268aad1..657f6b7b8 100644 --- a/src/expr/dtype_cons.h +++ b/src/expr/dtype_cons.h @@ -85,6 +85,12 @@ class DTypeConstructor * DType must be resolved. */ Node getConstructor() const; + /** + * Get the specialized constructor term of this constructor, which is + * the constructor wrapped in a APPLY_TYPE_ASCRIPTION. This is required + * for constructing applications of constructors for parametric datatypes. + */ + Node getInstantiatedConstructor(TypeNode returnType) const; /** * Get the tester operator of this constructor. The @@ -139,7 +145,7 @@ class DTypeConstructor * "cons" constructor type for lists of int---namely, * "int -> list[int] -> list[int]". */ - TypeNode getSpecializedConstructorType(TypeNode returnType) const; + TypeNode getInstantiatedConstructorType(TypeNode returnType) const; /** * Return the cardinality of this constructor (the product of the diff --git a/src/theory/datatypes/datatypes_rewriter.cpp b/src/theory/datatypes/datatypes_rewriter.cpp index 903a08bb4..b475d51e7 100644 --- a/src/theory/datatypes/datatypes_rewriter.cpp +++ b/src/theory/datatypes/datatypes_rewriter.cpp @@ -329,10 +329,7 @@ RewriteResponse DatatypesRewriter::preRewrite(TNode in) // get the constructor object const DTypeConstructor& dtc = utils::datatypeOf(op)[utils::indexOf(op)]; // create ascribed constructor type - Node tc = NodeManager::currentNM()->mkConst( - AscriptionType(dtc.getSpecializedConstructorType(tn))); - Node op_new = NodeManager::currentNM()->mkNode( - kind::APPLY_TYPE_ASCRIPTION, tc, op); + Node op_new = dtc.getInstantiatedConstructor(tn); // make new node std::vector children; children.push_back(op_new); @@ -891,7 +888,14 @@ TrustNode DatatypesRewriter::expandDefinition(Node n) size_t cindex = utils::cindexOf(op); const DTypeConstructor& dc = dt[cindex]; NodeBuilder b(APPLY_CONSTRUCTOR); - b << dc.getConstructor(); + if (tn.isParametricDatatype()) + { + b << dc.getInstantiatedConstructor(n[0].getType()); + } + else + { + b << dc.getConstructor(); + } Trace("dt-expand") << "Expand updater " << n << std::endl; Trace("dt-expand") << "expr is " << n << std::endl; Trace("dt-expand") << "updateIndex is " << updateIndex << std::endl; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index a9f0c3198..3f11ab1da 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -1244,7 +1244,7 @@ bool TheoryDatatypes::collectModelValues(TheoryModel* m, for( unsigned i=0; imkNode(APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(tspec)), - cchildren[0]); + cchildren[0] = dt[index].getInstantiatedConstructor(tn); } return nm->mkNode(APPLY_CONSTRUCTOR, cchildren); } diff --git a/src/theory/datatypes/type_enumerator.cpp b/src/theory/datatypes/type_enumerator.cpp index 6528f1052..69ebc9c78 100644 --- a/src/theory/datatypes/type_enumerator.cpp +++ b/src/theory/datatypes/type_enumerator.cpp @@ -143,11 +143,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ NodeBuilder b(kind::APPLY_CONSTRUCTOR); if (d_datatype.isParametric()) { - NodeManager* nm = NodeManager::currentNM(); - TypeNode typ = ctor.getSpecializedConstructorType(d_type); - b << nm->mkNode(kind::APPLY_TYPE_ASCRIPTION, - nm->mkConst(AscriptionType(typ)), - ctor.getConstructor()); + b << ctor.getInstantiatedConstructor(d_type); } else { @@ -245,7 +241,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){ TypeNode typ; if (d_datatype.isParametric()) { - typ = ctor.getSpecializedConstructorType(d_type); + typ = ctor.getInstantiatedConstructorType(d_type); } for (unsigned a = 0; a < ctor.getNumArgs(); ++a) { diff --git a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp index 9556d3f9c..ec33fe5fd 100644 --- a/src/theory/quantifiers/cegqi/ceg_instantiator.cpp +++ b/src/theory/quantifiers/cegqi/ceg_instantiator.cpp @@ -357,7 +357,7 @@ CegHandledStatus CegInstantiator::isCbqiSort( if (dt.isParametric()) { // if parametric, must instantiate the argument types - consType = dt[i].getSpecializedConstructorType(tn); + consType = dt[i].getInstantiatedConstructorType(tn); } else { diff --git a/src/theory/quantifiers/quant_split.cpp b/src/theory/quantifiers/quant_split.cpp index 55fa2a1e5..e6cee778b 100644 --- a/src/theory/quantifiers/quant_split.cpp +++ b/src/theory/quantifiers/quant_split.cpp @@ -167,7 +167,7 @@ void QuantDSplit::check(Theory::Effort e, QEffort quant_e) for (unsigned j = 0, ncons = dt.getNumConstructors(); j < ncons; j++) { std::vector vars; - TypeNode dtjtn = dt[j].getSpecializedConstructorType(tn); + TypeNode dtjtn = dt[j].getInstantiatedConstructorType(tn); Assert(dtjtn.getNumChildren() == dt[j].getNumArgs() + 1); for (unsigned k = 0, nargs = dt[j].getNumArgs(); k < nargs; k++) { diff --git a/src/theory/quantifiers/quantifiers_rewriter.cpp b/src/theory/quantifiers/quantifiers_rewriter.cpp index ba10a2efc..2002c73db 100644 --- a/src/theory/quantifiers/quantifiers_rewriter.cpp +++ b/src/theory/quantifiers/quantifiers_rewriter.cpp @@ -906,9 +906,9 @@ bool QuantifiersRewriter::getVarElimLit(Node body, // take into account if parametric if (dt.isParametric()) { - tspec = c.getSpecializedConstructorType(lit[0].getType()); - cons = nm->mkNode( - APPLY_TYPE_ASCRIPTION, nm->mkConst(AscriptionType(tspec)), cons); + TypeNode ltn = lit[0].getType(); + tspec = c.getInstantiatedConstructorType(ltn); + cons = c.getInstantiatedConstructor(ltn); } else { diff --git a/src/theory/quantifiers/skolemize.cpp b/src/theory/quantifiers/skolemize.cpp index f116b2f3c..9f2f9c91c 100644 --- a/src/theory/quantifiers/skolemize.cpp +++ b/src/theory/quantifiers/skolemize.cpp @@ -137,8 +137,8 @@ void Skolemize::getSelfSel(const DType& dt, TypeNode tspec; if (dt.isParametric()) { - tspec = dc.getSpecializedConstructorType(n.getType()); - Trace("sk-ind-debug") << "Specialized constructor type : " << tspec + tspec = dc.getInstantiatedConstructorType(n.getType()); + Trace("sk-ind-debug") << "Instantiated constructor type : " << tspec << std::endl; Assert(tspec.getNumChildren() == dc.getNumArgs()); } diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 438afbe82..95d3a5ab5 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -483,7 +483,7 @@ void CegGrammarConstructor::collectSygusGrammarTypesFor( { // get the specialized constructor type, which accounts for // parametric datatypes - TypeNode ctn = dt[i].getSpecializedConstructorType(range); + TypeNode ctn = dt[i].getInstantiatedConstructorType(range); std::vector argTypes = ctn.getArgTypes(); for (size_t j = 0, nargs = argTypes.size(); j < nargs; ++j) { @@ -1010,12 +1010,11 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( { Trace("sygus-grammar-def") << "...for " << dt[l].getName() << std::endl; Node cop = dt[l].getConstructor(); - TypeNode tspec = dt[l].getSpecializedConstructorType(types[i]); + TypeNode tspec = dt[l].getInstantiatedConstructorType(types[i]); // must specialize if a parametric datatype if (dt.isParametric()) { - cop = nm->mkNode( - APPLY_TYPE_ASCRIPTION, nm->mkConst(AscriptionType(tspec)), cop); + cop = dt[l].getInstantiatedConstructor(types[i]); } if (dt[l].getNumArgs() == 0) { diff --git a/test/unit/api/cpp/solver_black.cpp b/test/unit/api/cpp/solver_black.cpp index c268ee4f8..2df5de4b8 100644 --- a/test/unit/api/cpp/solver_black.cpp +++ b/test/unit/api/cpp/solver_black.cpp @@ -2750,5 +2750,29 @@ TEST_F(TestApiBlackSolver, getDatatypeArity) ASSERT_EQ(s3.getDatatypeArity(), 0); } +TEST_F(TestApiBlackSolver, proj_issue381) +{ + Sort s1 = d_solver.getBooleanSort(); + + Sort psort = d_solver.mkParamSort("_x9"); + DatatypeDecl dtdecl = d_solver.mkDatatypeDecl("_x8", psort); + DatatypeConstructorDecl ctor = d_solver.mkDatatypeConstructorDecl("_x22"); + ctor.addSelector("_x19", s1); + dtdecl.addConstructor(ctor); + Sort s3 = d_solver.mkDatatypeSort(dtdecl); + Sort s6 = s3.instantiate({s1}); + Term t26 = d_solver.mkConst(s6, "_x63"); + Term t5 = d_solver.mkTrue(); + Term t187 = d_solver.mkTerm(APPLY_UPDATER, + t26.getSort() + .getDatatype() + .getConstructor("_x22") + .getSelector("_x19") + .getUpdaterTerm(), + t26, + t5); + ASSERT_NO_THROW(d_solver.simplify(t187)); +} + } // namespace test } // namespace cvc5 -- 2.30.2