From 8d54316e7ff784a8d66da9ecc2d289ab463761c2 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 13 May 2011 22:02:52 +0000 Subject: [PATCH] added support for parametric datatypes, updated cvc parser to handle parametric datatypes, type ascriptions are not implemented yet --- src/expr/declaration_scope.cpp | 10 +- src/expr/declaration_scope.h | 5 + src/expr/expr_manager_template.cpp | 42 ++++- src/expr/expr_manager_template.h | 2 +- src/expr/node_manager.h | 3 +- src/expr/type.cpp | 57 +++++- src/expr/type.h | 19 +- src/expr/type_node.cpp | 14 ++ src/expr/type_node.h | 8 + src/parser/cvc/Cvc.g | 75 +++++--- src/parser/parser.cpp | 30 +++- src/parser/parser.h | 15 +- src/theory/datatypes/datatypes_rewriter.h | 3 +- src/theory/datatypes/kinds | 9 + src/theory/datatypes/theory_datatypes.cpp | 3 +- .../datatypes/theory_datatypes_type_rules.h | 169 +++++++++++++++--- src/util/datatype.cpp | 70 +++++++- src/util/datatype.h | 43 ++++- 18 files changed, 497 insertions(+), 80 deletions(-) diff --git a/src/expr/declaration_scope.cpp b/src/expr/declaration_scope.cpp index 8dd329b83..79accf43a 100644 --- a/src/expr/declaration_scope.cpp +++ b/src/expr/declaration_scope.cpp @@ -144,7 +144,10 @@ Type DeclarationScope::lookupType(const std::string& name, Debug("sort") << "instance is " << instantiation << endl; return instantiation; - } else { + }else if( p.second.isDatatype() ){ + Assert( DatatypeType( p.second ).isParametric() ); + return DatatypeType(p.second).instantiate(params); + }else { if(Debug.isOn("sort")) { Debug("sort") << "instantiating using a sort substitution" << endl; Debug("sort") << "have formals ["; @@ -167,6 +170,11 @@ Type DeclarationScope::lookupType(const std::string& name, } } +size_t DeclarationScope::lookupArity( const std::string& name ){ + pair, Type> p = (*d_typeMap->find(name)).second; + return p.first.size(); +} + void DeclarationScope::popScope() throw(ScopeException) { if( d_context->getLevel() == 0 ) { throw ScopeException(); diff --git a/src/expr/declaration_scope.h b/src/expr/declaration_scope.h index 699dca6fa..4cdb71ddc 100644 --- a/src/expr/declaration_scope.h +++ b/src/expr/declaration_scope.h @@ -174,6 +174,11 @@ public: Type lookupType(const std::string& name, const std::vector& params) const throw(AssertionException); + /** + * Lookup the arity of a bound parameterized type. + */ + size_t lookupArity( const std::string& name ); + /** * Pop a scope level. Deletes all bindings since the last call to * pushScope. Calls to pushScope and diff --git a/src/expr/expr_manager_template.cpp b/src/expr/expr_manager_template.cpp index f0c90ebdb..c32dbbc7d 100644 --- a/src/expr/expr_manager_template.cpp +++ b/src/expr/expr_manager_template.cpp @@ -486,12 +486,12 @@ DatatypeType ExprManager::mkDatatypeType(const Datatype& datatype) { std::vector ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes) { - return mkMutualDatatypeTypes(datatypes, set()); + return mkMutualDatatypeTypes(datatypes, set()); } std::vector ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes, - const std::set& unresolvedTypes) { + const std::set& unresolvedTypes) { NodeManagerScope nms(d_nodeManager); std::map nameResolutions; std::vector dtts; @@ -505,7 +505,18 @@ ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes, i_end = datatypes.end(); i != i_end; ++i) { - TypeNode* typeNode = new TypeNode(d_nodeManager->mkTypeConst(*i)); + TypeNode* typeNode; + if( (*i).getNumParameters()==0 ){ + typeNode = new TypeNode(d_nodeManager->mkTypeConst(*i)); + }else{ + TypeNode cons = d_nodeManager->mkTypeConst(*i); + std::vector< TypeNode > params; + params.push_back( cons ); + for( unsigned int ip=0; ip<(*i).getNumParameters(); ip++ ){ + params.push_back( TypeNode::fromType( (*i).getParameter( ip ) ) ); + } + typeNode = new TypeNode( d_nodeManager->mkTypeNode( kind::PARAMETRIC_DATATYPE, params ) ); + } Type type(d_nodeManager, typeNode); DatatypeType dtt(type); CheckArgument(nameResolutions.find((*i).getName()) == nameResolutions.end(), @@ -526,13 +537,21 @@ ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes, // // @TODO get rid of named resolutions altogether and handle // everything with these resolutions? + std::vector< SortConstructorType > paramTypes; + std::vector< DatatypeType > paramReplacements; std::vector placeholders;// to hold the "unresolved placeholders" std::vector replacements;// to hold our final, resolved types - for(std::set::const_iterator i = unresolvedTypes.begin(), + for(std::set::const_iterator i = unresolvedTypes.begin(), i_end = unresolvedTypes.end(); i != i_end; ++i) { - std::string name = (*i).getName(); + std::string name; + if( (*i).isSort() ){ + name = SortType(*i).getName(); + }else{ + Assert( (*i).isSortConstructor() ); + name = SortConstructorType(*i).getName(); + } std::map::const_iterator resolver = nameResolutions.find(name); CheckArgument(resolver != nameResolutions.end(), @@ -543,8 +562,14 @@ ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes, // unresolved SortType used as a placeholder in complex types) // with "(*resolver).second" (the DatatypeType we created in the // first step, above). - placeholders.push_back(*i); - replacements.push_back((*resolver).second); + if( (*i).isSort() ){ + placeholders.push_back(*i); + replacements.push_back( (*resolver).second ); + }else{ + Assert( (*i).isSortConstructor() ); + paramTypes.push_back( SortConstructorType(*i) ); + paramReplacements.push_back( (*resolver).second ); + } } // Lastly, perform the final resolutions and checks. @@ -555,7 +580,8 @@ ExprManager::mkMutualDatatypeTypes(const std::vector& datatypes, const Datatype& dt = (*i).getDatatype(); if(!dt.isResolved()) { const_cast(dt).resolve(this, nameResolutions, - placeholders, replacements); + placeholders, replacements, + paramTypes, paramReplacements); } // Now run some checks, including a check to make sure that no diff --git a/src/expr/expr_manager_template.h b/src/expr/expr_manager_template.h index f395d781c..712273473 100644 --- a/src/expr/expr_manager_template.h +++ b/src/expr/expr_manager_template.h @@ -357,7 +357,7 @@ public: */ std::vector mkMutualDatatypeTypes(const std::vector& datatypes, - const std::set& unresolvedTypes); + const std::set& unresolvedTypes); /** * Make a type representing a constructor with the given parameterization. diff --git a/src/expr/node_manager.h b/src/expr/node_manager.h index 9974df6ca..8b803e696 100644 --- a/src/expr/node_manager.h +++ b/src/expr/node_manager.h @@ -53,10 +53,12 @@ namespace expr { namespace attr { struct VarNameTag {}; struct SortArityTag {}; + struct DatatypeTag {}; }/* CVC4::expr::attr namespace */ typedef Attribute VarNameAttr; typedef Attribute SortArityAttr; +typedef Attribute DatatypeAttr; }/* CVC4::expr namespace */ @@ -1188,7 +1190,6 @@ inline TypeNode NodeManager::mkTypeNode(Kind kind, return NodeBuilder<>(this, kind).append(children).constructTypeNode(); } - inline Node NodeManager::mkVar(const std::string& name, const TypeNode& type) { Node n = mkVar(type); n.setAttribute(TypeAttr(), type); diff --git a/src/expr/type.cpp b/src/expr/type.cpp index 567bb2d40..2bcdcedfa 100644 --- a/src/expr/type.cpp +++ b/src/expr/type.cpp @@ -225,7 +225,7 @@ Type::operator DatatypeType() const throw(AssertionException) { /** Is this the Datatype type? */ bool Type::isDatatype() const { NodeManagerScope nms(d_nodeManager); - return d_typeNode->isDatatype(); + return d_typeNode->isDatatype() || d_typeNode->isParametricDatatype(); } /** Cast to a Constructor type */ @@ -388,6 +388,18 @@ string SortType::getName() const { return d_typeNode->getAttribute(expr::VarNameAttr()); } +bool SortType::isParameterized() const +{ + return false; +} + +/** Get the parameter types */ +std::vector SortType::getParamTypes() const +{ + vector params; + return params; +} + string SortConstructorType::getName() const { NodeManagerScope nms(d_nodeManager); return d_typeNode->getAttribute(expr::VarNameAttr()); @@ -514,7 +526,48 @@ std::vector ConstructorType::getArgTypes() const { } const Datatype& DatatypeType::getDatatype() const { - return d_typeNode->getConst(); + if( d_typeNode->isParametricDatatype() ){ + Assert( (*d_typeNode)[0].getKind()==kind::DATATYPE_TYPE ); + const Datatype& dt = (*d_typeNode)[0].getConst(); + return dt; + }else{ + return d_typeNode->getConst(); + } +} + +bool DatatypeType::isParametric() const { + return d_typeNode->isParametricDatatype(); +} + +size_t DatatypeType::getArity() const { + NodeManagerScope nms(d_nodeManager); + return d_typeNode->getNumChildren() - 1; +} + +std::vector DatatypeType::getParamTypes() const{ + NodeManagerScope nms(d_nodeManager); + vector params; + vector paramNodes = d_typeNode->getParamTypes(); + vector::iterator it = paramNodes.begin(); + vector::iterator it_end = paramNodes.end(); + for(; it != it_end; ++ it) { + params.push_back(makeType(*it)); + } + return params; +} + +DatatypeType DatatypeType::instantiate(const std::vector& params) const { + NodeManagerScope nms(d_nodeManager); + TypeNode cons = d_nodeManager->mkTypeConst( getDatatype() ); + vector paramsNodes; + paramsNodes.push_back( cons ); + for(vector::const_iterator i = params.begin(), + iend = params.end(); + i != iend; + ++i) { + paramsNodes.push_back(*getTypeNode(*i)); + } + return DatatypeType(makeType(d_nodeManager->mkTypeNode(kind::PARAMETRIC_DATATYPE,paramsNodes))); } DatatypeType SelectorType::getDomain() const { diff --git a/src/expr/type.h b/src/expr/type.h index 980a750d5..096336b0c 100644 --- a/src/expr/type.h +++ b/src/expr/type.h @@ -469,6 +469,12 @@ public: /** Get the name of the sort */ std::string getName() const; + + /** is parameterized */ + bool isParameterized() const; + + /** Get the parameter types */ + std::vector getParamTypes() const; };/* class SortType */ /** @@ -533,8 +539,19 @@ public: /** Get the underlying datatype */ const Datatype& getDatatype() const; -};/* class DatatypeType */ + /** is parameterized */ + bool isParametric() const; + + /** Get the parameter types */ + std::vector getParamTypes() const; + /** Get the arity of the datatype constructor */ + size_t getArity() const; + + /** Instantiate a datatype using this datatype constructor */ + DatatypeType instantiate(const std::vector& params) const; + +};/* class DatatypeType */ /** * Class encapsulating the constructor type diff --git a/src/expr/type_node.cpp b/src/expr/type_node.cpp index a6ca39015..9283da13a 100644 --- a/src/expr/type_node.cpp +++ b/src/expr/type_node.cpp @@ -136,6 +136,15 @@ std::vector TypeNode::getArgTypes() const { return args; } +std::vector TypeNode::getParamTypes() const { + vector params; + Assert( isParametricDatatype() ); + for(unsigned i = 1, i_end = getNumChildren(); i < i_end; ++i) { + params.push_back((*this)[i]); + } + return params; +} + TypeNode TypeNode::getRangeType() const { if(isTester()) { return NodeManager::currentNM()->booleanType(); @@ -185,6 +194,11 @@ bool TypeNode::isDatatype() const { return getKind() == kind::DATATYPE_TYPE; } +/** Is this a datatype type */ +bool TypeNode::isParametricDatatype() const { + return getKind() == kind::PARAMETRIC_DATATYPE; +} + /** Is this a constructor type */ bool TypeNode::isConstructor() const { return getKind() == kind::CONSTRUCTOR_TYPE; diff --git a/src/expr/type_node.h b/src/expr/type_node.h index 7f6ebd471..d6c685a75 100644 --- a/src/expr/type_node.h +++ b/src/expr/type_node.h @@ -451,6 +451,11 @@ public: */ std::vector getArgTypes() const; + /** + * Get the paramater types of a parameterized datatype. + */ + std::vector getParamTypes() const; + /** * Get the range type (i.e., the type of the result) of a function, * datatype constructor, datatype selector, or datatype tester. @@ -479,6 +484,9 @@ public: /** Is this a datatype type */ bool isDatatype() const; + /** Is this a parameterized datatype type */ + bool isParametricDatatype() const; + /** Is this a constructor type */ bool isConstructor() const; diff --git a/src/parser/cvc/Cvc.g b/src/parser/cvc/Cvc.g index b3c253dab..3c8d6e1ce 100644 --- a/src/parser/cvc/Cvc.g +++ b/src/parser/cvc/Cvc.g @@ -1002,12 +1002,27 @@ restrictedTypePossiblyFunctionLHS[CVC4::Type& t, } /* named types */ : identifier[id,check,SYM_SORT] - parameterization[check]? - { if(check == CHECK_DECLARED || + parameterization[check,types]? + { + if(check == CHECK_DECLARED || PARSER_STATE->isDeclared(id, SYM_SORT)) { - t = PARSER_STATE->getSort(id); + Debug("parser-param") << "param: getSort " << id << " " << types.size() << " " << PARSER_STATE->getArity( id ) + << " " << PARSER_STATE->isDeclared(id, SYM_SORT) << std::endl; + if( types.size()>0 ){ + t = PARSER_STATE->getSort(id, types); + }else{ + t = PARSER_STATE->getSort(id); + } } else { - t = PARSER_STATE->mkUnresolvedType(id); + if( types.empty() ){ + t = PARSER_STATE->mkUnresolvedType(id); + Debug("parser-param") << "param: make unres type " << id << std::endl; + }else{ + t = PARSER_STATE->mkUnresolvedTypeConstructor(id,types); + t = SortConstructorType(t).instantiate( types ); + Debug("parser-param") << "param: make unres param type " << id << " " << types.size() << " " + << PARSER_STATE->getArity( id ) << std::endl; + } } } @@ -1076,12 +1091,13 @@ restrictedTypePossiblyFunctionLHS[CVC4::Type& t, } ; -parameterization[CVC4::parser::DeclarationCheck check] +parameterization[CVC4::parser::DeclarationCheck check, + std::vector& params] @init { Type t; } - : LBRACKET restrictedType[t,check] ( COMMA restrictedType[t,check] )* RBRACKET - { UNSUPPORTED("parameterized types not yet supported"); } + : LBRACKET restrictedType[t,check] { Debug("parser-param") << "t = " << t << std::endl; params.push_back( t ); } + ( COMMA restrictedType[t,check] { Debug("parser-param") << "t = " << t << std::endl; params.push_back( t ); } )* RBRACKET ; bound returns [CVC4::parser::cvc::mySubrangeBound bound] @@ -1267,6 +1283,7 @@ term[CVC4::Expr& f] std::vector expressions; std::vector operators; unsigned op; + Type t; } : storeTerm[f] { expressions.push_back(f); } ( arithmeticBinop[op] storeTerm[f] { operators.push_back(op); expressions.push_back(f); } )* @@ -1374,6 +1391,7 @@ postfixTerm[CVC4::Expr& f] bool extract = false, left = false; std::vector args; std::string id; + Type t; } : bvTerm[f] ( /* array select / bitvector extract */ @@ -1443,7 +1461,9 @@ postfixTerm[CVC4::Expr& f] f = EXPR_MANAGER->mkVar(TupleType(f.getType()).getTypes()[k]); } )*/ )* - typeAscription[f]? + (typeAscription[f, t] + { //f = MK_EXPR(CVC4::kind::APPLY_TYPE_ANNOTATION, MK_CONST(t), f); + } )? ; bvTerm[CVC4::Expr& f] @@ -1642,20 +1662,19 @@ simpleTerm[CVC4::Expr& f] * Matches (and performs) a type ascription. * The f arg is the term to check (it is an input-only argument). */ -typeAscription[const CVC4::Expr& f] +typeAscription[const CVC4::Expr& f, CVC4::Type& t] @init { - Type t; } : COLON COLON type[t,CHECK_DECLARED] - { if(f.getType() != t) { - std::stringstream ss; - ss << Expr::setlanguage(language::output::LANG_CVC4) - << "type ascription not satisfied\n" - << "term: " << f << '\n' - << "has type: " << f.getType() << '\n' - << "ascribed: " << t; - PARSER_STATE->parseError(ss.str()); - } + { //if(f.getType() != t) { + // std::stringstream ss; + // ss << Expr::setlanguage(language::output::LANG_CVC4) + // << "type ascription not satisfied\n" + // << "term: " << f << '\n' + // << "has type: " << f.getType() << '\n' + // << "ascribed: " << t; + // PARSER_STATE->parseError(ss.str()); + //} } ; @@ -1706,16 +1725,23 @@ iteElseTerm[CVC4::Expr& f] datatypeDef[std::vector& datatypes] @init { std::string id, id2; + Type t; + std::vector< Type > params; } /* This really needs to be CHECK_NONE, or mutually-recursive datatypes * won't work, because this type will already be "defined" as an * unresolved type; don't worry, we check below. */ - : identifier[id,CHECK_NONE,SYM_SORT] - ( LBRACKET identifier[id2,CHECK_NONE,SYM_SORT] - ( COMMA identifier[id2,CHECK_NONE,SYM_SORT] )* RBRACKET - { UNSUPPORTED("parameterized datatypes not yet supported"); } + : identifier[id,CHECK_NONE,SYM_SORT] { PARSER_STATE->pushScope(); } + ( LBRACKET identifier[id2,CHECK_UNDECLARED,SYM_SORT] { + t = PARSER_STATE->mkSort(id2); + params.push_back( t ); + } + ( COMMA identifier[id2,CHECK_UNDECLARED,SYM_SORT] { + t = PARSER_STATE->mkSort(id2); + params.push_back( t ); } + )* RBRACKET )? - { datatypes.push_back(Datatype(id)); + { datatypes.push_back(Datatype(id,params)); if(!PARSER_STATE->isUnresolvedType(id)) { // if not unresolved, must be undeclared PARSER_STATE->checkDeclaration(id, CHECK_UNDECLARED, SYM_SORT); @@ -1723,6 +1749,7 @@ datatypeDef[std::vector& datatypes] } EQUAL_TOK constructorDef[datatypes.back()] ( BAR constructorDef[datatypes.back()] )* + { PARSER_STATE->popScope(); } ; /** diff --git a/src/parser/parser.cpp b/src/parser/parser.cpp index 29ade43c1..efe01fb40 100644 --- a/src/parser/parser.cpp +++ b/src/parser/parser.cpp @@ -95,6 +95,11 @@ Type Parser::getSort(const std::string& name, return t; } +size_t Parser::getArity(const std::string& sort_name){ + Assert( isDeclared(sort_name, SYM_SORT) ); + return d_declScope->lookupArity(sort_name); +} + /* Returns true if name is bound to a boolean variable. */ bool Parser::isBoolean(const std::string& name) { return isDeclared(name, SYM_VARIABLE) && getType(name).isBoolean(); @@ -237,6 +242,24 @@ SortType Parser::mkUnresolvedType(const std::string& name) { return unresolved; } +SortConstructorType Parser::mkUnresolvedTypeConstructor(const std::string& name, + size_t arity) +{ + SortConstructorType unresolved = mkSortConstructor(name,arity); + d_unresolved.insert(unresolved); + return unresolved; +} +SortConstructorType Parser::mkUnresolvedTypeConstructor(const std::string& name, + const std::vector& params){ + Debug("parser") << "newSortConstructor(P)(" << name << ", " << params.size() << ")" + << std::endl; + SortConstructorType unresolved = d_exprManager->mkSortConstructor(name, params.size()); + defineType(name, params, unresolved); + Type t = getSort( name, params ); + d_unresolved.insert(unresolved); + return unresolved; +} + bool Parser::isUnresolvedType(const std::string& name) { if(!isDeclared(name, SYM_SORT)) { return false; @@ -260,7 +283,12 @@ Parser::mkMutualDatatypeTypes(const std::vector& datatypes) { if(isDeclared(name, SYM_SORT)) { throw ParserException(name + " already declared"); } - defineType(name, t); + if( t.isParametric() ){ + std::vector< Type > paramTypes = t.getParamTypes(); + defineType(name, paramTypes, t ); + }else{ + defineType(name, t); + } for(Datatype::const_iterator j = dt.begin(), j_end = dt.end(); j != j_end; diff --git a/src/parser/parser.h b/src/parser/parser.h index 6509b192b..6d55e195e 100644 --- a/src/parser/parser.h +++ b/src/parser/parser.h @@ -150,7 +150,7 @@ class CVC4_PUBLIC Parser { * depend on mkMutualDatatypeTypes() to check everything and clear * this out. */ - std::set d_unresolved; + std::set d_unresolved; /** * "Preemption commands": extra commands implied by subterms that @@ -254,6 +254,11 @@ public: Type getSort(const std::string& sort_name, const std::vector& params); + /** + * Returns arity of a (parameterized) sort, given a name and args. + */ + size_t getArity(const std::string& sort_name); + /** * Checks if a symbol has been declared. * @param name the symbol name @@ -367,6 +372,14 @@ public: */ SortType mkUnresolvedType(const std::string& name); + /** + * Creates a new "unresolved type," used only during parsing. + */ + SortConstructorType mkUnresolvedTypeConstructor(const std::string& name, + size_t arity); + SortConstructorType mkUnresolvedTypeConstructor(const std::string& name, + const std::vector& params); + /** * Returns true IFF name is an unresolved type. */ diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index eea15d6b6..df8cb6eb8 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -43,7 +43,8 @@ public: return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(result)); } else { - const Datatype& dt = in[0].getType().getConst(); + //const Datatype& dt = in[0].getType().getConst(); + const Datatype& dt = DatatypeType(in[0].getType().toType()).getDatatype(); if(dt.getNumConstructors() == 1) { // only one constructor, so it must be Debug("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds index 37a65a2b0..d08e3875c 100644 --- a/src/theory/datatypes/kinds +++ b/src/theory/datatypes/kinds @@ -53,4 +53,13 @@ well-founded DATATYPE_TYPE \ "%TYPE%.getConst().mkGroundTerm()" \ "util/datatype.h" +operator PARAMETRIC_DATATYPE 1: "parametric datatype" +cardinality PARAMETRIC_DATATYPE \ + "DatatypeType(%TYPE%.toType()).getDatatype().getCardinality()" \ + "util/datatype.h" +well-founded PARAMETRIC_DATATYPE\ + "DatatypeType(%TYPE%.toType()).getDatatype().isWellFounded()" \ + "DatatypeType(%TYPE%.toType()).getDatatype().mkGroundTerm()" \ + "util/datatype.h" + endtheory diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 6808ef749..2f0b82f7c 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -851,7 +851,8 @@ void TheoryDatatypes::addTermToLabels( Node t ) { const Datatype& dt = ((DatatypeType)(t.getType()).toType()).getDatatype(); if( dt.getNumConstructors()==1 ){ Node tester = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[0].getTester() ), t ); - addTester( tester ); + lbl->push_back( tester ); + d_checkMap[ t ] = true; d_em.addNodeAxiom( tester, Reason::idt_texhaust ); } d_labels.insertDataFromContextMemory(tmp, lbl); diff --git a/src/theory/datatypes/theory_datatypes_type_rules.h b/src/theory/datatypes/theory_datatypes_type_rules.h index bc1581f14..dc2e95f9d 100644 --- a/src/theory/datatypes/theory_datatypes_type_rules.h +++ b/src/theory/datatypes/theory_datatypes_type_rules.h @@ -32,6 +32,70 @@ namespace expr { namespace theory { namespace datatypes { +class Matcher +{ +private: + std::vector< TypeNode > d_types; + std::vector< TypeNode > d_match; +public: + Matcher(){} + Matcher( DatatypeType dt ){ + std::vector< Type > argTypes = dt.getParamTypes(); + addTypes( argTypes ); + } + ~Matcher(){} + + void addType( Type t ){ + d_types.push_back( TypeNode::fromType( t ) ); + d_match.push_back( TypeNode::null() ); + } + void addTypes( std::vector< Type > types ){ + for( int i=0; i<(int)types.size(); i++ ){ + addType( types[i] ); + } + } + + bool doMatching( TypeNode base, TypeNode match ){ + std::vector< TypeNode >::iterator i = std::find( d_types.begin(), d_types.end(), base ); + if( i!=d_types.end() ){ + int index = i - d_types.begin(); + if( !d_match[index].isNull() && d_match[index]!=match ){ + return false; + }else{ + d_match[ i - d_types.begin() ] = match; + return true; + } + }else if( base==match ){ + return true; + }else if( base.getKind()!=match.getKind() || base.getNumChildren()!=match.getNumChildren() ){ + return false; + }else{ + for( int i=0; i<(int)base.getNumChildren(); i++ ){ + if( !doMatching( base[i], match[i] ) ){ + return false; + } + } + return true; + } + } + + TypeNode getMatch( unsigned int i ){ return d_match[i]; } + void getTypes( std::vector& types ) { + types.clear(); + for( int i=0; i<(int)d_match.size(); i++ ){ + types.push_back( d_types[i].toType() ); + } + } + void getMatches( std::vector& types ) { + types.clear(); + for( int i=0; i<(int)d_match.size(); i++ ){ + Assert( !d_match[i].isNull() ); //verify that all types have been set + types.push_back( d_match[i].toType() ); + } + } +}; + + typedef expr::Attribute GroundTermAttr; struct DatatypeConstructorTypeRule { @@ -39,24 +103,43 @@ struct DatatypeConstructorTypeRule { throw(TypeCheckingExceptionPrivate) { Assert(n.getKind() == kind::APPLY_CONSTRUCTOR); TypeNode consType = n.getOperator().getType(check); - if(check) { - Debug("typecheck-idt") << "typecheck cons: " << n << " " << n.getNumChildren() << std::endl; - Debug("typecheck-idt") << "cons type: " << consType << " " << consType.getNumChildren() << std::endl; - if(n.getNumChildren() != consType.getNumChildren() - 1) { - throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the constructor type"); - } - TNode::iterator child_it = n.begin(); - TNode::iterator child_it_end = n.end(); - TypeNode::iterator tchild_it = consType.begin(); + Type t = consType.getConstructorRangeType().toType(); + Assert( t.isDatatype() ); + DatatypeType dt = DatatypeType(t); + TNode::iterator child_it = n.begin(); + TNode::iterator child_it_end = n.end(); + TypeNode::iterator tchild_it = consType.begin(); + if( ( dt.isParametric() || check ) && n.getNumChildren() != consType.getNumChildren() - 1 ){ + throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the constructor type"); + } + if( dt.isParametric() ){ + Debug("typecheck-idt") << "typecheck parameterized datatype " << n << std::endl; + Matcher m( dt ); for(; child_it != child_it_end; ++child_it, ++tchild_it) { TypeNode childType = (*child_it).getType(check); - Debug("typecheck-idt") << "typecheck cons arg: " << childType << " " << (*tchild_it) << std::endl; - if(childType != *tchild_it) { - throw TypeCheckingExceptionPrivate(n, "bad type for constructor argument"); + if( !m.doMatching( *tchild_it, childType ) ){ + throw TypeCheckingExceptionPrivate(n, "matching failed for parameterized constructor"); + } + } + std::vector< Type > instTypes; + m.getMatches( instTypes ); + TypeNode range = TypeNode::fromType( dt.instantiate( instTypes ) ); + Debug("typecheck-idt") << "Return " << range << std::endl; + return range; + }else{ + if(check) { + Debug("typecheck-idt") << "typecheck cons: " << n << " " << n.getNumChildren() << std::endl; + Debug("typecheck-idt") << "cons type: " << consType << " " << consType.getNumChildren() << std::endl; + for(; child_it != child_it_end; ++child_it, ++tchild_it) { + TypeNode childType = (*child_it).getType(check); + Debug("typecheck-idt") << "typecheck cons arg: " << childType << " " << (*tchild_it) << std::endl; + if(childType != *tchild_it) { + throw TypeCheckingExceptionPrivate(n, "bad type for constructor argument"); + } } } + return consType.getConstructorRangeType(); } - return consType.getConstructorRangeType(); } };/* struct DatatypeConstructorTypeRule */ @@ -65,18 +148,38 @@ struct DatatypeSelectorTypeRule { throw(TypeCheckingExceptionPrivate) { Assert(n.getKind() == kind::APPLY_SELECTOR); TypeNode selType = n.getOperator().getType(check); - Debug("typecheck-idt") << "typecheck sel: " << n << std::endl; - Debug("typecheck-idt") << "sel type: " << selType << std::endl; - if(check) { - if(n.getNumChildren() != 1) { - throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the selector type"); - } + Type t = selType[0].toType(); + Assert( t.isDatatype() ); + DatatypeType dt = DatatypeType(t); + if( ( dt.isParametric() || check ) && n.getNumChildren() != 1 ){ + throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the selector type"); + } + if( dt.isParametric() ){ + Debug("typecheck-idt") << "typecheck parameterized sel: " << n << std::endl; + Matcher m( dt ); TypeNode childType = n[0].getType(check); - if(selType[0] != childType) { - throw TypeCheckingExceptionPrivate(n, "bad type for selector argument"); + if( !m.doMatching( selType[0], childType ) ){ + throw TypeCheckingExceptionPrivate(n, "matching failed for selector argument of parameterized datatype"); } + std::vector< Type > types, matches; + m.getTypes( types ); + m.getMatches( matches ); + Type range = selType[1].toType(); + range = range.substitute( types, matches ); + Debug("typecheck-idt") << "Return " << range << std::endl; + return TypeNode::fromType( range ); + }else{ + if(check) { + Debug("typecheck-idt") << "typecheck sel: " << n << std::endl; + Debug("typecheck-idt") << "sel type: " << selType << std::endl; + TypeNode childType = n[0].getType(check); + if(selType[0] != childType) { + Debug("typecheck-idt") << "ERROR: " << selType[0].getKind() << " " << childType.getKind() << std::endl; + throw TypeCheckingExceptionPrivate(n, "bad type for selector argument"); + } + } + return selType[1]; } - return selType[1]; } };/* struct DatatypeSelectorTypeRule */ @@ -90,10 +193,21 @@ struct DatatypeTesterTypeRule { } TypeNode testType = n.getOperator().getType(check); TypeNode childType = n[0].getType(check); - Debug("typecheck-idt") << "typecheck test: " << n << std::endl; - Debug("typecheck-idt") << "test type: " << testType << std::endl; - if(testType[0] != childType) { - throw TypeCheckingExceptionPrivate(n, "bad type for tester argument"); + Type t = testType[0].toType(); + Assert( t.isDatatype() ); + DatatypeType dt = DatatypeType(t); + if( dt.isParametric() ){ + Debug("typecheck-idt") << "typecheck parameterized tester: " << n << std::endl; + Matcher m( dt ); + if( !m.doMatching( testType[0], childType ) ){ + throw TypeCheckingExceptionPrivate(n, "matching failed for tester argument of parameterized datatype"); + } + }else{ + Debug("typecheck-idt") << "typecheck test: " << n << std::endl; + Debug("typecheck-idt") << "test type: " << testType << std::endl; + if(testType[0] != childType) { + throw TypeCheckingExceptionPrivate(n, "bad type for tester argument"); + } } } return nodeManager->booleanType(); @@ -140,7 +254,8 @@ struct ConstructorProperties { // Constructors within the same Datatype could share the same // type. So we scan through the datatype to find one that // matches. - const Datatype& dt = type[type.getNumChildren() - 1].getConst(); + //const Datatype& dt = type[type.getNumChildren() - 1].getConst(); + const Datatype& dt = DatatypeType(type[type.getNumChildren() - 1].toType()).getDatatype(); for(Datatype::const_iterator i = dt.begin(), i_end = dt.end(); i != i_end; diff --git a/src/util/datatype.cpp b/src/util/datatype.cpp index ab52e7f93..ecb089658 100644 --- a/src/util/datatype.cpp +++ b/src/util/datatype.cpp @@ -54,13 +54,15 @@ const Datatype& Datatype::datatypeOf(Expr item) { TypeNode t = Node::fromExpr(item).getType(); switch(t.getKind()) { case kind::CONSTRUCTOR_TYPE: - return t[t.getNumChildren() - 1].getConst(); + //return t[t.getNumChildren() - 1].getConst(); + return DatatypeType(t[t.getNumChildren() - 1].toType()).getDatatype(); case kind::SELECTOR_TYPE: case kind::TESTER_TYPE: - return t[0].getConst(); + //return t[0].getConst(); + return DatatypeType(t[0].toType()).getDatatype(); default: Unhandled("arg must be a datatype constructor, selector, or tester"); - } + } } size_t Datatype::indexOf(Expr item) { @@ -77,9 +79,12 @@ size_t Datatype::indexOf(Expr item) { void Datatype::resolve(ExprManager* em, const std::map& resolutions, const std::vector& placeholders, - const std::vector& replacements) + const std::vector& replacements, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements) throw(AssertionException, DatatypeResolutionException) { + //cout << "resolve " << *this << "..." << std::endl; AssertArgument(em != NULL, "cannot resolve a Datatype with a NULL expression manager"); CheckArgument(!d_resolved, "cannot resolve a Datatype twice"); AssertArgument(resolutions.find(d_name) != resolutions.end(), @@ -92,13 +97,15 @@ void Datatype::resolve(ExprManager* em, d_resolved = true; size_t index = 0; for(iterator i = begin(), i_end = end(); i != i_end; ++i) { - (*i).resolve(em, self, resolutions, placeholders, replacements); + (*i).resolve(em, self, resolutions, placeholders, replacements, paramTypes, paramReplacements); Assert((*i).isResolved()); Node::fromExpr((*i).d_constructor).setAttribute(DatatypeIndexAttr(), index); Node::fromExpr((*i).d_tester).setAttribute(DatatypeIndexAttr(), index++); } d_self = self; Assert(index == getNumConstructors()); + + //cout << "done resolve " << *this << std::endl; } void Datatype::addConstructor(const Constructor& c) { @@ -263,10 +270,16 @@ Expr Datatype::mkGroundTerm() const throw(AssertionException) { DatatypeType Datatype::getDatatypeType() const throw(AssertionException) { CheckArgument(isResolved(), *this, "Datatype must be resolved to get its DatatypeType"); - Assert(!d_self.isNull()); + Assert(!d_self.isNull() && !DatatypeType(d_self).isParametric()); return DatatypeType(d_self); } +DatatypeType Datatype::getDatatypeType(const std::vector& params) const throw(AssertionException) { + CheckArgument(isResolved(), *this, "Datatype must be resolved to get its DatatypeType"); + Assert(!d_self.isNull() && DatatypeType(d_self).isParametric()); + return DatatypeType(d_self).instantiate(params); +} + bool Datatype::operator==(const Datatype& other) const throw() { // two datatypes are == iff the name is the same and they have // exactly matching constructors (in the same order) @@ -349,8 +362,13 @@ const Datatype::Constructor& Datatype::operator[](size_t index) const { void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self, const std::map& resolutions, const std::vector& placeholders, - const std::vector& replacements) + const std::vector& replacements, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements) throw(AssertionException, DatatypeResolutionException) { + + //cout << "resolve " << *this << "..." << std::endl; + AssertArgument(em != NULL, "cannot resolve a Datatype with a NULL expression manager"); CheckArgument(!isResolved(), "cannot resolve a Datatype constructor twice; " @@ -383,6 +401,9 @@ void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self, if(!placeholders.empty()) { range = range.substitute(placeholders, replacements); } + if(!paramTypes.empty() ){ + range = doParametricSubstitution( range, paramTypes, paramReplacements ); + } (*i).d_selector = em->mkVar((*i).d_name, em->mkSelectorType(self, range)); } Node::fromExpr((*i).d_selector).setAttribute(DatatypeIndexAttr(), index++); @@ -403,6 +424,37 @@ void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self, for(iterator i = begin(), i_end = end(); i != i_end; ++i) { (*i).d_constructor = d_constructor; } + + //cout << "done resolve " << *this << std::endl; +} + +Type Datatype::Constructor::doParametricSubstitution( Type range, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements ){ + TypeNode typn = TypeNode::fromType( range ); + if(typn.getNumChildren() == 0) { + return range; + } else { + std::vector< Type > origChildren; + std::vector< Type > children; + for(TypeNode::const_iterator i = typn.begin(), iend = typn.end();i != iend; ++i) { + origChildren.push_back( (*i).toType() ); + children.push_back( doParametricSubstitution( (*i).toType(), paramTypes, paramReplacements ) ); + } + for( int i=0; i<(int)paramTypes.size(); i++ ){ + if( paramTypes[i].getArity()==origChildren.size() ){ + Type tn = paramTypes[i].instantiate( origChildren ); + if( range==tn ){ + return paramReplacements[i].instantiate( children ); + } + } + } + NodeBuilder<> nb(typn.getKind()); + for( int i=0; i<(int)children.size(); i++ ){ + nb << TypeNode::fromType( children[i] ); + } + return nb.constructTypeNode().toType(); + } } Datatype::Constructor::Constructor(std::string name, std::string tester) : @@ -615,6 +667,10 @@ Expr Datatype::Constructor::Arg::getConstructor() const { return d_constructor; } +Type Datatype::Constructor::Arg::getSelectorType() const{ + return getSelector().getType(); +} + bool Datatype::Constructor::Arg::isUnresolvedSelf() const throw() { return d_selector.isNull() && d_name.size() == d_name.find('\0') + 1; } diff --git a/src/util/datatype.h b/src/util/datatype.h index 7d9ae6f39..abc9e3258 100644 --- a/src/util/datatype.h +++ b/src/util/datatype.h @@ -174,6 +174,12 @@ public: */ Expr getConstructor() const; + /** + * Get the selector for this constructor argument; this call is + * only permitted after resolution. + */ + Type getSelectorType() const; + /** * Get the name of the type of this constructor argument * (Datatype field). Can be used for not-yet-resolved Datatypes @@ -204,10 +210,16 @@ public: void resolve(ExprManager* em, DatatypeType self, const std::map& resolutions, const std::vector& placeholders, - const std::vector& replacements) + const std::vector& replacements, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements) throw(AssertionException, DatatypeResolutionException); friend class Datatype; + /** */ + Type doParametricSubstitution( Type range, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements ); public: /** * Create a new Datatype constructor with the given name for the @@ -314,6 +326,7 @@ public: private: std::string d_name; + std::vector d_params; std::vector d_constructors; bool d_resolved; Type d_self; @@ -330,14 +343,16 @@ private: void resolve(ExprManager* em, const std::map& resolutions, const std::vector& placeholders, - const std::vector& replacements) + const std::vector& replacements, + const std::vector< SortConstructorType >& paramTypes, + const std::vector< DatatypeType >& paramReplacements) throw(AssertionException, DatatypeResolutionException); friend class ExprManager;// for access to resolve() public: /** Create a new Datatype of the given name. */ - inline explicit Datatype(std::string name); + inline explicit Datatype(std::string name, std::vector& params); /** Add a constructor to this Datatype. */ void addConstructor(const Constructor& c); @@ -347,6 +362,11 @@ public: /** Get the number of constructors (so far) for this Datatype. */ inline size_t getNumConstructors() const throw(); + /** Get the nubmer of parameters */ + inline size_t getNumParameters() const throw(); + /** Get parameter */ + inline Type getParameter( unsigned int i ) const; + /** * Return the cardinality of this datatype (the sum of the * cardinalities of its constructors). The Datatype must be @@ -381,6 +401,12 @@ public: */ DatatypeType getDatatypeType() const throw(AssertionException); + /** + * Get the DatatypeType associated to this (parameterized) Datatype. Can only be + * called post-resolution. + */ + DatatypeType getDatatypeType(const std::vector& params) const throw(AssertionException); + /** * Return true iff the two Datatypes are the same. * @@ -466,8 +492,9 @@ inline std::string Datatype::UnresolvedType::getName() const throw() { return d_name; } -inline Datatype::Datatype(std::string name) : +inline Datatype::Datatype(std::string name, std::vector& params) : d_name(name), + d_params(params), d_constructors(), d_resolved(false), d_self() { @@ -481,6 +508,14 @@ inline size_t Datatype::getNumConstructors() const throw() { return d_constructors.size(); } +inline size_t Datatype::getNumParameters() const throw() { + return d_params.size(); +} + +inline Type Datatype::getParameter( unsigned int i ) const { + return d_params[i]; +} + inline bool Datatype::operator!=(const Datatype& other) const throw() { return !(*this == other); } -- 2.30.2