From ee689a569dfe33fc7245c5133ff5de5de57cd1d9 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 1 Apr 2022 13:16:43 -0500 Subject: [PATCH] Internal simplifications to constructing datatypes (#8519) In preparation for #8511. This makes it so that unresolved sorts are automatically inferred when constructing datatypes at the NodeManager level. This is in preparation for simplifying the API. Changes: (1) NodeManager is cleaned so that unresolved types are automatically inferred and are not a part of its public interface. Internal code generating datatypes is simplified as a result. (2) Adds necessary utilities to TypeNode and cleans an unused flag (3) The parser is cleaned to not track unresolved sorts in an ad-hoc manner. (4) The API is patched to use the simpler interface. --- src/api/cpp/cvc5.cpp | 20 ++--- src/expr/dtype.cpp | 46 ++++++++++ src/expr/dtype.h | 8 ++ src/expr/dtype_cons.cpp | 6 +- src/expr/dtype_cons.h | 2 +- src/expr/dtype_selector.cpp | 6 +- src/expr/node_manager_attributes.h | 7 ++ src/expr/node_manager_template.cpp | 29 +++++-- src/expr/node_manager_template.h | 83 +++++++++---------- src/expr/type_node.cpp | 5 ++ src/expr/type_node.h | 3 + src/parser/parser.cpp | 25 +----- src/parser/parser.h | 17 ---- src/preprocessing/passes/synth_rew_rules.cpp | 8 +- src/theory/datatypes/sygus_datatype_utils.cpp | 11 +-- src/theory/quantifiers/sygus/cegis_unif.cpp | 6 +- .../quantifiers/sygus/sygus_grammar_cons.cpp | 8 +- .../quantifiers/sygus/sygus_grammar_norm.cpp | 5 +- test/unit/node/type_cardinality_black.cpp | 2 +- test/unit/util/datatype_black.cpp | 24 ++---- 20 files changed, 170 insertions(+), 151 deletions(-) diff --git a/src/api/cpp/cvc5.cpp b/src/api/cpp/cvc5.cpp index f5d8f6ad4..beaa4887a 100644 --- a/src/api/cpp/cvc5.cpp +++ b/src/api/cpp/cvc5.cpp @@ -4419,7 +4419,9 @@ Sort Grammar::resolve() // make the unresolved type, used for referencing the final version of // the ntsymbol's datatype ntsToUnres[ntsymbol] = - Sort(d_solver, d_solver->getNodeManager()->mkSort(ntsymbol.toString())); + Sort(d_solver, + d_solver->getNodeManager()->mkUnresolvedDatatypeSort( + ntsymbol.toString())); } std::vector datatypes; @@ -4461,7 +4463,6 @@ Sort Grammar::resolve() std::vector datatypeTypes = d_solver->getNodeManager()->mkMutualDatatypeTypes( datatypes, - unresTypes, internal::NodeManager::DATATYPE_FLAG_PLACEHOLDER); // return is the first datatype @@ -5178,10 +5179,8 @@ std::vector Solver::mkDatatypeSortsInternal( datatypes.push_back(dtypedecls[i].getDatatype()); } - std::set utypes = - Sort::sortSetToTypeNodes(unresolvedSorts); std::vector dtypes = - getNodeManager()->mkMutualDatatypeTypes(datatypes, utypes); + getNodeManager()->mkMutualDatatypeTypes(datatypes); std::vector retTypes = Sort::typeNodeVectorToSorts(this, dtypes); return retTypes; } @@ -5544,10 +5543,7 @@ Sort Solver::mkParamSort(const std::optional& symbol) const //////// all checks before this line internal::TypeNode tn = - symbol ? getNodeManager()->mkSort( - *symbol, internal::NodeManager::SORT_FLAG_PLACEHOLDER) - : getNodeManager()->mkSort( - internal::NodeManager::SORT_FLAG_PLACEHOLDER); + symbol ? getNodeManager()->mkSort(*symbol) : getNodeManager()->mkSort(); return Sort(this, tn); //////// CVC5_API_TRY_CATCH_END; @@ -5633,11 +5629,7 @@ Sort Solver::mkUnresolvedSort(const std::string& symbol, size_t arity) const { CVC5_API_TRY_CATCH_BEGIN; //////// all checks before this line - if (arity) - { - return Sort(this, getNodeManager()->mkSortConstructor(symbol, arity)); - } - return Sort(this, getNodeManager()->mkSort(symbol)); + return Sort(this, getNodeManager()->mkUnresolvedDatatypeSort(symbol, arity)); //////// CVC5_API_TRY_CATCH_END; } diff --git a/src/expr/dtype.cpp b/src/expr/dtype.cpp index f2329ed2e..6af405fdc 100644 --- a/src/expr/dtype.cpp +++ b/src/expr/dtype.cpp @@ -143,6 +143,52 @@ size_t DType::cindexOfInternal(Node item) return item.getAttribute(DTypeConsIndexAttr()); } +void DType::collectUnresolvedDatatypeTypes(std::set& unresTypes) const +{ + // Scan the arguments of all constructors and collect their types. To be + // robust to datatypes with nested recursion, we collect the *component* + // types of all subfield types and store them in csfTypes. In other words, we + // search for unresolved datatypes that occur possibly as parameters to + // other parametric types. + std::unordered_set csfTypes; + for (const std::shared_ptr& ctor : d_constructors) + { + for (size_t i = 0, nargs = ctor->getNumArgs(); i < nargs; i++) + { + Node sel = (*ctor)[i].d_selector; + if (sel.isNull()) + { + // we currently permit null selector for representing self selectors, + // skip these. + continue; + } + // The selector has *not* been initialized to a variable of selector type, + // which is done during resolve. Instead, we get the raw type of sel + // and compute its component types. + expr::getComponentTypes(sel.getType(), csfTypes); + } + } + // Now, process each component type + for (const TypeNode& arg : csfTypes) + { + if (arg.isUnresolvedDatatype()) + { + // it is an unresolved datatype + unresTypes.insert(arg); + } + else if (arg.isInstantiatedUninterpretedSort()) + { + // it might be an instantiated sort constructor corresponding to a + // unresolved parametric datatype, in which case we extract its operator + TypeNode argc = arg.getUninterpretedSortConstructor(); + if (argc.isUnresolvedDatatype()) + { + unresTypes.insert(argc); + } + } + } +} + bool DType::resolve(const std::map& resolutions, const std::vector& placeholders, const std::vector& replacements, diff --git a/src/expr/dtype.h b/src/expr/dtype.h index 0fe1b902b..596a2acad 100644 --- a/src/expr/dtype.h +++ b/src/expr/dtype.h @@ -435,6 +435,14 @@ class DType void toStream(std::ostream& out) const; private: + /** + * Collect unresolved datatype types. This is called by NodeManager when + * constructing datatypes from datatype declarations. This adds all + * unresolved datatype types to unresTypes, which are then considered + * when constructing the datatype (for details, see + * NodeManager::mkMutualDatatypeTypesInternal). + */ + void collectUnresolvedDatatypeTypes(std::set& unresTypes) const; /** * DTypes refer to themselves, recursively, and we have a * chicken-and-egg problem. The TypeNode around the DType diff --git a/src/expr/dtype_cons.cpp b/src/expr/dtype_cons.cpp index 287dc0f2a..2c779b4fd 100644 --- a/src/expr/dtype_cons.cpp +++ b/src/expr/dtype_cons.cpp @@ -40,17 +40,17 @@ DTypeConstructor::DTypeConstructor(std::string name, Assert(name != ""); } -void DTypeConstructor::addArg(std::string selectorName, TypeNode selectorType) +void DTypeConstructor::addArg(std::string selectorName, TypeNode rangeType) { // We don't want to introduce a new data member, because eventually // we're going to be a constant stuffed inside a node. So we stow // the selector type away inside a var until resolution (when we can // create the proper selector type) Assert(!isResolved()); - Assert(!selectorType.isNull()); + Assert(!rangeType.isNull()); SkolemManager* sm = NodeManager::currentNM()->getSkolemManager(); Node sel = sm->mkDummySkolem("unresolved_" + selectorName, - selectorType, + rangeType, "is an unresolved selector type placeholder", SkolemManager::SKOLEM_EXACT_NAME); // can use null updater for now diff --git a/src/expr/dtype_cons.h b/src/expr/dtype_cons.h index b137d947c..3b0fa3571 100644 --- a/src/expr/dtype_cons.h +++ b/src/expr/dtype_cons.h @@ -57,7 +57,7 @@ class DTypeConstructor * to this constructor. Selector names need not be unique; * they are for convenience and pretty-printing only. */ - void addArg(std::string selectorName, TypeNode selectorType); + void addArg(std::string selectorName, TypeNode rangeType); /** * Add an argument, given a pointer to a selector object. */ diff --git a/src/expr/dtype_selector.cpp b/src/expr/dtype_selector.cpp index d60c682d0..068168987 100644 --- a/src/expr/dtype_selector.cpp +++ b/src/expr/dtype_selector.cpp @@ -44,7 +44,11 @@ Node DTypeSelector::getConstructor() const return d_constructor; } -TypeNode DTypeSelector::getType() const { return d_selector.getType(); } +TypeNode DTypeSelector::getType() const +{ + Assert(!d_selector.isNull()); + return d_selector.getType(); +} TypeNode DTypeSelector::getRangeType() const { diff --git a/src/expr/node_manager_attributes.h b/src/expr/node_manager_attributes.h index 72f33977a..c9bbe47bd 100644 --- a/src/expr/node_manager_attributes.h +++ b/src/expr/node_manager_attributes.h @@ -30,6 +30,9 @@ namespace attr { struct SortArityTag { }; struct TypeTag { }; struct TypeCheckedTag { }; + struct UnresolvedDatatypeTag + { + }; } // namespace attr typedef Attribute VarNameAttr; @@ -37,5 +40,9 @@ typedef Attribute SortArityAttr; typedef expr::Attribute TypeAttr; typedef expr::Attribute TypeCheckedAttr; +/** Attribute is true for unresolved datatype sorts */ +using UnresolvedDatatypeAttr = + expr::Attribute; + } // namespace expr } // namespace cvc5::internal diff --git a/src/expr/node_manager_template.cpp b/src/expr/node_manager_template.cpp index a1fdc48fb..9a5d575b2 100644 --- a/src/expr/node_manager_template.cpp +++ b/src/expr/node_manager_template.cpp @@ -571,10 +571,15 @@ std::vector NodeManager::mkMutualDatatypeTypes( const std::vector& datatypes, uint32_t flags) { std::set unresolvedTypes; - return mkMutualDatatypeTypes(datatypes, unresolvedTypes, flags); + // scan the list of datatypes to find unresolved datatypes + for (const DType& dt : datatypes) + { + dt.collectUnresolvedDatatypeTypes(unresolvedTypes); + } + return mkMutualDatatypeTypesInternal(datatypes, unresolvedTypes, flags); } -std::vector NodeManager::mkMutualDatatypeTypes( +std::vector NodeManager::mkMutualDatatypeTypesInternal( const std::vector& datatypes, const std::set& unresolvedTypes, uint32_t flags) @@ -880,7 +885,7 @@ void NodeManager::reclaimZombiesUntil(uint32_t k) size_t NodeManager::poolSize() const { return d_nodeValuePool.size(); } -TypeNode NodeManager::mkSort(uint32_t flags) +TypeNode NodeManager::mkSort() { NodeBuilder nb(this, kind::SORT_TYPE); Node sortTag = NodeBuilder(this, kind::SORT_TAG); @@ -888,7 +893,7 @@ TypeNode NodeManager::mkSort(uint32_t flags) return nb.constructTypeNode(); } -TypeNode NodeManager::mkSort(const std::string& name, uint32_t flags) +TypeNode NodeManager::mkSort(const std::string& name) { NodeBuilder nb(this, kind::SORT_TYPE); Node sortTag = NodeBuilder(this, kind::SORT_TAG); @@ -899,8 +904,7 @@ TypeNode NodeManager::mkSort(const std::string& name, uint32_t flags) } TypeNode NodeManager::mkSort(TypeNode constructor, - const std::vector& children, - uint32_t flags) + const std::vector& children) { Assert(constructor.getKind() == kind::SORT_TYPE && constructor.getNumChildren() == 0) @@ -922,9 +926,7 @@ TypeNode NodeManager::mkSort(TypeNode constructor, return type; } -TypeNode NodeManager::mkSortConstructor(const std::string& name, - size_t arity, - uint32_t flags) +TypeNode NodeManager::mkSortConstructor(const std::string& name, size_t arity) { Assert(arity > 0); NodeBuilder nb(this, kind::SORT_TYPE); @@ -936,6 +938,15 @@ TypeNode NodeManager::mkSortConstructor(const std::string& name, return type; } +TypeNode NodeManager::mkUnresolvedDatatypeSort(const std::string& name, + size_t arity) +{ + TypeNode usort = arity > 0 ? mkSortConstructor(name, arity) : mkSort(name); + // mark that it is an unresolved sort + setAttribute(usort, expr::UnresolvedDatatypeAttr(), true); + return usort; +} + Node NodeManager::mkVar(const std::string& name, const TypeNode& type) { Node n = NodeBuilder(this, kind::VARIABLE); diff --git a/src/expr/node_manager_template.h b/src/expr/node_manager_template.h index 49e88ca1b..2d384f0d5 100644 --- a/src/expr/node_manager_template.h +++ b/src/expr/node_manager_template.h @@ -502,40 +502,6 @@ class NodeManager std::vector mkMutualDatatypeTypes( const std::vector& datatypes, uint32_t flags = DATATYPE_FLAG_NONE); - /** - * Make a set of types representing the given datatypes, which may - * be mutually recursive. unresolvedTypes is a set of SortTypes - * that were used as placeholders in the Datatypes for the Datatypes - * of the same name. This is just a more complicated version of the - * above mkMutualDatatypeTypes() function, but is required to handle - * complex types. - * - * For example, unresolvedTypes might contain the single sort "list" - * (with that name reported from SortType::getName()). The - * datatypes list might have the single datatype - * - * DATATYPE - * list = cons(car:ARRAY INT OF list, cdr:list) | nil; - * END; - * - * To represent the Type of the array, the user had to create a - * placeholder type (an uninterpreted sort) to stand for "list" in - * the type of "car". It is this placeholder sort that should be - * passed in unresolvedTypes. If the datatype was of the simpler - * form: - * - * DATATYPE - * list = cons(car:list, cdr:list) | nil; - * END; - * - * then no complicated Type needs to be created, and the above, - * simpler form of mkMutualDatatypeTypes() is enough. - */ - std::vector mkMutualDatatypeTypes( - const std::vector& datatypes, - const std::set& unresolvedTypes, - uint32_t flags = DATATYPE_FLAG_NONE); - /** * Make a type representing a constructor with the given argument (subfield) * types and return type range. @@ -787,22 +753,55 @@ class NodeManager TypeNode mkTypeNode(Kind kind, const std::vector& children); /** Make a new (anonymous) sort of arity 0. */ - TypeNode mkSort(uint32_t flags = SORT_FLAG_NONE); + TypeNode mkSort(); /** Make a new sort with the given name of arity 0. */ - TypeNode mkSort(const std::string& name, uint32_t flags = SORT_FLAG_NONE); + TypeNode mkSort(const std::string& name); /** Make a new sort by parameterizing the given sort constructor. */ - TypeNode mkSort(TypeNode constructor, - const std::vector& children, - uint32_t flags = SORT_FLAG_NONE); + TypeNode mkSort(TypeNode constructor, const std::vector& children); /** Make a new sort with the given name and arity. */ - TypeNode mkSortConstructor(const std::string& name, - size_t arity, - uint32_t flags = SORT_FLAG_NONE); + TypeNode mkSortConstructor(const std::string& name, size_t arity); + + /** Make an unresolved datatype sort */ + TypeNode mkUnresolvedDatatypeSort(const std::string& name, size_t arity = 0); private: + /** + * Make a set of types representing the given datatypes, which may + * be mutually recursive. unresolvedTypes is a set of SortTypes + * that were used as placeholders in the Datatypes for the Datatypes + * of the same name. This is just a more complicated version of the + * above mkMutualDatatypeTypes() function, but is required to handle + * complex types. + * + * For example, unresolvedTypes might contain the single sort "list" + * (with that name reported from SortType::getName()). The + * datatypes list might have the single datatype + * + * DATATYPE + * list = cons(car:ARRAY INT OF list, cdr:list) | nil; + * END; + * + * To represent the Type of the array, the user had to create a + * placeholder type (an uninterpreted sort) to stand for "list" in + * the type of "car". It is this placeholder sort that should be + * passed in unresolvedTypes. If the datatype was of the simpler + * form: + * + * DATATYPE + * list = cons(car:list, cdr:list) | nil; + * END; + * + * then no complicated Type needs to be created, and the above, + * simpler form of mkMutualDatatypeTypes() is enough. + */ + std::vector mkMutualDatatypeTypesInternal( + const std::vector& datatypes, + const std::set& unresolvedTypes, + uint32_t flags = DATATYPE_FLAG_NONE); + typedef std::unordered_set diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index 1278e9bf2..a2a7ec360 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -464,6 +464,11 @@ uint64_t TypeNode::getUninterpretedSortConstructorArity() const return getAttribute(expr::SortArityAttr()); } +bool TypeNode::isUnresolvedDatatype() const +{ + return getAttribute(expr::UnresolvedDatatypeAttr()); +} + std::string TypeNode::getName() const { Assert(isUninterpretedSort() || isUninterpretedSortConstructor()); diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 8f844dfb0..ed7aedb99 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -672,6 +672,9 @@ private: /** Get sort constructor arity. */ uint64_t getUninterpretedSortConstructorArity() const; + /** Is this an unresolved datatype? */ + bool isUnresolvedDatatype() const; + /** * Get name, for uninterpreted sorts and uninterpreted sort constructors. */ diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 5dacfc8a8..22593e274 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -324,19 +324,16 @@ cvc5::Sort Parser::mkSortConstructor(const std::string& name, size_t arity) cvc5::Sort Parser::mkUnresolvedType(const std::string& name) { - cvc5::Sort unresolved = d_solver->mkUninterpretedSort(name); + cvc5::Sort unresolved = d_solver->mkUnresolvedSort(name); defineType(name, unresolved); - d_unresolved.insert(unresolved); return unresolved; } cvc5::Sort Parser::mkUnresolvedTypeConstructor(const std::string& name, size_t arity) { - cvc5::Sort unresolved = - d_solver->mkUninterpretedSortConstructorSort(arity, name); + cvc5::Sort unresolved = d_solver->mkUnresolvedSort(name, arity); defineType(name, vector(arity), unresolved); - d_unresolved.insert(unresolved); return unresolved; } @@ -346,10 +343,9 @@ cvc5::Sort Parser::mkUnresolvedTypeConstructor( Trace("parser") << "newSortConstructor(P)(" << name << ", " << params.size() << ")" << std::endl; cvc5::Sort unresolved = - d_solver->mkUninterpretedSortConstructorSort(params.size(), name); + d_solver->mkUnresolvedSort(name, params.size()); defineType(name, params, unresolved); cvc5::Sort t = getSort(name, params); - d_unresolved.insert(unresolved); return unresolved; } @@ -362,19 +358,11 @@ cvc5::Sort Parser::mkUnresolvedType(const std::string& name, size_t arity) return mkUnresolvedTypeConstructor(name, arity); } -bool Parser::isUnresolvedType(const std::string& name) { - if (!isDeclared(name, SYM_SORT)) { - return false; - } - return d_unresolved.find(getSort(name)) != d_unresolved.end(); -} - std::vector Parser::bindMutualDatatypeTypes( std::vector& datatypes, bool doOverload) { try { - std::vector types = - d_solver->mkDatatypeSorts(datatypes, d_unresolved); + std::vector types = d_solver->mkDatatypeSorts(datatypes); Assert(datatypes.size() == types.size()); @@ -442,11 +430,6 @@ std::vector Parser::bindMutualDatatypeTypes( } } - // These are no longer used, and the ExprManager would have - // complained of a bad substitution if anything is left unresolved. - // Clear out the set. - d_unresolved.clear(); - // throw exception if any datatype is not well-founded for (unsigned i = 0; i < datatypes.size(); ++i) { const cvc5::Datatype& dt = types[i].getDatatype(); diff --git a/src/parser/parser.h b/src/parser/parser.h index de7e029fa..4034f5efa 100644 --- a/src/parser/parser.h +++ b/src/parser/parser.h @@ -164,15 +164,6 @@ private: /** The set of attributes already warned about. */ std::set d_attributesWarnedAbout; - /** - * The current set of unresolved types. We can get by with this NOT - * being on the scope, because we can only have one DATATYPE - * definition going on at one time. This is a bit hackish; we - * depend on mkMutualDatatypeTypes() to check everything and clear - * this out. - */ - std::set d_unresolved; - /** * "Preemption commands": extra commands implied by subterms that * should be issued before the currently-being-parsed command is @@ -220,9 +211,6 @@ public: /** Get the associated input. */ Input* getInput() const { return d_input.get(); } - /** Get unresolved sorts */ - inline std::set& getUnresolvedSorts() { return d_unresolved; } - /** Deletes and replaces the current parser input. */ void setInput(Input* input) { d_input.reset(input); @@ -508,11 +496,6 @@ public: */ cvc5::Sort mkUnresolvedType(const std::string& name, size_t arity); - /** - * Returns true IFF name is an unresolved type. - */ - bool isUnresolvedType(const std::string& name); - /** * Creates and binds sorts of a list of mutually-recursive datatype * declarations. diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp index ab8950b76..741aafcbc 100644 --- a/src/preprocessing/passes/synth_rew_rules.cpp +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -263,7 +263,6 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( Trace("srs-input") << "Construct unresolved types..." << std::endl; // each canonical subterm corresponds to a grammar type - std::set unres; std::vector sdts; // make unresolved types for each canonical term std::map cterm_to_utype; @@ -273,9 +272,8 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( std::stringstream ss; ss << "T" << i; std::string tname = ss.str(); - TypeNode tnu = nm->mkSort(tname, NodeManager::SORT_FLAG_PLACEHOLDER); + TypeNode tnu = nm->mkUnresolvedDatatypeSort(tname); cterm_to_utype[ct] = tnu; - unres.insert(tnu); sdts.push_back(SygusDatatype(tname)); } Trace("srs-input") << "...finished." << std::endl; @@ -398,9 +396,9 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( datatypes.push_back(sdts[i].getDatatype()); } std::vector types = nm->mkMutualDatatypeTypes( - datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + datatypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); Trace("srs-input") << "...finished." << std::endl; - Assert(types.size() == unres.size()); + Assert(types.size() == datatypes.size()); std::map subtermTypes; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) { diff --git a/src/theory/datatypes/sygus_datatype_utils.cpp b/src/theory/datatypes/sygus_datatype_utils.cpp index 4379a3837..f2d0827c0 100644 --- a/src/theory/datatypes/sygus_datatype_utils.cpp +++ b/src/theory/datatypes/sygus_datatype_utils.cpp @@ -457,7 +457,6 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, // must convert all constructors to version with variables in "vars" std::vector sdts; - std::set unres; Trace("dtsygus-gen-debug") << "Process sygus type:" << std::endl; Trace("dtsygus-gen-debug") << sdtd.getName() << std::endl; @@ -469,9 +468,7 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, dtToProcess.push_back(sdt); std::stringstream ssutn0; ssutn0 << sdtd.getName() << "_s"; - TypeNode abdTNew = - nm->mkSort(ssutn0.str(), NodeManager::SORT_FLAG_PLACEHOLDER); - unres.insert(abdTNew); + TypeNode abdTNew = nm->mkUnresolvedDatatypeSort(ssutn0.str()); dtProcessed[sdt] = abdTNew; // We must convert all symbols in the sygus datatype type sdt to @@ -511,11 +508,9 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, { std::stringstream ssutn; ssutn << argt.getDType().getName() << "_s"; - argtNew = - nm->mkSort(ssutn.str(), NodeManager::SORT_FLAG_PLACEHOLDER); + argtNew = nm->mkUnresolvedDatatypeSort(ssutn.str()); Trace("dtsygus-gen-debug") << " ...unresolved type " << argtNew << " for " << argt << std::endl; - unres.insert(argtNew); dtProcessed[argt] = argtNew; dtNextToProcess.push_back(argt); } @@ -552,7 +547,7 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt, } // make the datatype types std::vector datatypeTypes = nm->mkMutualDatatypeTypes( - datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + datatypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); TypeNode sdtS = datatypeTypes[0]; if (TraceIsOn("dtsygus-gen-debug")) { diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index d8a8e3e53..194bf48a1 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -476,9 +476,7 @@ Node CegisUnifEnumDecisionStrategy::mkLiteral(unsigned n) Node bvl; std::string veName("_virtual_enum_grammar"); SygusDatatype sdt(veName); - TypeNode u = nm->mkSort(veName, NodeManager::SORT_FLAG_PLACEHOLDER); - std::set unresolvedTypes; - unresolvedTypes.insert(u); + TypeNode u = nm->mkUnresolvedDatatypeSort(veName); std::vector cargsEmpty; Node cr = nm->mkConstInt(Rational(1)); sdt.addConstructor(cr, "1", cargsEmpty); @@ -490,7 +488,7 @@ Node CegisUnifEnumDecisionStrategy::mkLiteral(unsigned n) std::vector datatypes; datatypes.push_back(sdt.getDatatype()); std::vector dtypes = nm->mkMutualDatatypeTypes( - datatypes, unresolvedTypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + datatypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); d_virtual_enum = sm->mkDummySkolem("_ve", dtypes[0]); d_tds->registerEnumerator( d_virtual_enum, Node::null(), d_parent, ROLE_ENUM_CONSTRAINED); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 3d9c67b92..2e42ce473 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -395,8 +395,8 @@ Node CegGrammarConstructor::convertToEmbedding(Node n) TypeNode CegGrammarConstructor::mkUnresolvedType(const std::string& name, std::set& unres) { - TypeNode unresolved = NodeManager::currentNM()->mkSort( - name, NodeManager::SORT_FLAG_PLACEHOLDER); + TypeNode unresolved = + NodeManager::currentNM()->mkUnresolvedDatatypeSort(name); unres.insert(unresolved); return unresolved; } @@ -1545,7 +1545,7 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType( Trace("sygus-grammar-def") << "...made " << datatypes.size() << " datatypes, now make mutual datatype types..." << std::endl; Assert(!datatypes.empty()); std::vector types = NodeManager::currentNM()->mkMutualDatatypeTypes( - datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + datatypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); Trace("sygus-grammar-def") << "...finished" << std::endl; Assert(types.size() == datatypes.size()); return types[0]; @@ -1592,7 +1592,7 @@ TypeNode CegGrammarConstructor::mkSygusTemplateTypeRec( Node templ, Node templ_a } std::vector types = NodeManager::currentNM()->mkMutualDatatypeTypes( - datatypes, unres, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + datatypes, NodeManager::DATATYPE_FLAG_PLACEHOLDER); Assert(types.size() == 1); return types[0]; } diff --git a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp index 02a01d73f..b0c65fd5c 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_norm.cpp @@ -58,8 +58,7 @@ bool OpPosTrie::getOrMakeType(TypeNode tn, { ss << "_" << std::to_string(op_pos[i]); } - d_unres_tn = NodeManager::currentNM()->mkSort( - ss.str(), NodeManager::SORT_FLAG_PLACEHOLDER); + d_unres_tn = NodeManager::currentNM()->mkUnresolvedDatatypeSort(ss.str()); Trace("sygus-grammar-normalize-trie") << "\tCreating type " << d_unres_tn << "\n"; unres_tn = d_unres_tn; @@ -528,7 +527,7 @@ TypeNode SygusGrammarNorm::normalizeSygusType(TypeNode tn, Node sygus_vars) } Assert(d_dt_all.size() == d_unres_t_all.size()); std::vector types = NodeManager::currentNM()->mkMutualDatatypeTypes( - d_dt_all, d_unres_t_all, NodeManager::DATATYPE_FLAG_PLACEHOLDER); + d_dt_all, NodeManager::DATATYPE_FLAG_PLACEHOLDER); Assert(types.size() == d_dt_all.size()); /* Clear accumulators */ d_dt_all.clear(); diff --git a/test/unit/node/type_cardinality_black.cpp b/test/unit/node/type_cardinality_black.cpp index 858ca3790..84bf0c788 100644 --- a/test/unit/node/type_cardinality_black.cpp +++ b/test/unit/node/type_cardinality_black.cpp @@ -307,7 +307,7 @@ TEST_F(TestNodeBlackTypeCardinality, ternary_functions) TEST_F(TestNodeBlackTypeCardinality, undefined_sorts) { - TypeNode foo = d_nodeManager->mkSort("foo", NodeManager::SORT_FLAG_NONE); + TypeNode foo = d_nodeManager->mkSort("foo"); // We've currently assigned them a specific Beth number, which // isn't really correct, but... ASSERT_FALSE(foo.getCardinality().isFinite()); diff --git a/test/unit/util/datatype_black.cpp b/test/unit/util/datatype_black.cpp index 1992c4154..b3ad83f83 100644 --- a/test/unit/util/datatype_black.cpp +++ b/test/unit/util/datatype_black.cpp @@ -275,13 +275,8 @@ TEST_F(TestUtilBlackDatatype, mutual_list_trees1) * list = cons(car: tree, cdr: list) | nil * END; */ - std::set unresolvedTypes; - TypeNode unresList = - d_nodeManager->mkSort("list", NodeManager::SORT_FLAG_PLACEHOLDER); - unresolvedTypes.insert(unresList); - TypeNode unresTree = - d_nodeManager->mkSort("tree", NodeManager::SORT_FLAG_PLACEHOLDER); - unresolvedTypes.insert(unresTree); + TypeNode unresList = d_nodeManager->mkSort("list"); + TypeNode unresTree = d_nodeManager->mkSort("tree"); DType tree("tree"); std::shared_ptr node = @@ -316,8 +311,7 @@ TEST_F(TestUtilBlackDatatype, mutual_list_trees1) std::vector dts; dts.push_back(tree); dts.push_back(list); - std::vector dtts = - d_nodeManager->mkMutualDatatypeTypes(dts, unresolvedTypes); + std::vector dtts = d_nodeManager->mkMutualDatatypeTypes(dts); ASSERT_TRUE(dtts[0].getDType().isResolved()); ASSERT_TRUE(dtts[1].getDType().isResolved()); @@ -345,13 +339,8 @@ TEST_F(TestUtilBlackDatatype, mutual_list_trees1) TEST_F(TestUtilBlackDatatype, mutual_list_trees2) { - std::set unresolvedTypes; - TypeNode unresList = - d_nodeManager->mkSort("list", NodeManager::SORT_FLAG_PLACEHOLDER); - unresolvedTypes.insert(unresList); - TypeNode unresTree = - d_nodeManager->mkSort("tree", NodeManager::SORT_FLAG_PLACEHOLDER); - unresolvedTypes.insert(unresTree); + TypeNode unresList = d_nodeManager->mkSort("list"); + TypeNode unresTree = d_nodeManager->mkSort("tree"); DType tree("tree"); std::shared_ptr node = @@ -386,8 +375,7 @@ TEST_F(TestUtilBlackDatatype, mutual_list_trees2) dts.push_back(tree); dts.push_back(list); // remake the types - std::vector dtts2 = - d_nodeManager->mkMutualDatatypeTypes(dts, unresolvedTypes); + std::vector dtts2 = d_nodeManager->mkMutualDatatypeTypes(dts); ASSERT_FALSE(dtts2[0].getDType().isFinite()); ASSERT_TRUE( -- 2.30.2