added support for parametric datatypes, updated cvc parser to handle parametric datat...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 13 May 2011 22:02:52 +0000 (22:02 +0000)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 13 May 2011 22:02:52 +0000 (22:02 +0000)
18 files changed:
src/expr/declaration_scope.cpp
src/expr/declaration_scope.h
src/expr/expr_manager_template.cpp
src/expr/expr_manager_template.h
src/expr/node_manager.h
src/expr/type.cpp
src/expr/type.h
src/expr/type_node.cpp
src/expr/type_node.h
src/parser/cvc/Cvc.g
src/parser/parser.cpp
src/parser/parser.h
src/theory/datatypes/datatypes_rewriter.h
src/theory/datatypes/kinds
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes_type_rules.h
src/util/datatype.cpp
src/util/datatype.h

index 8dd329b8390fb1523c77944f1ad81c5807eb3903..79accf43a066f808e94efd0957812971d33d07a6 100644 (file)
@@ -144,7 +144,10 @@ Type DeclarationScope::lookupType(const std::string& name,
     Debug("sort") << "instance is  " << instantiation << endl;
 
     return instantiation;
-  } else {
+  }else if( p.second.isDatatype() ){
+    Assert( DatatypeType( p.second ).isParametric() );
+    return DatatypeType(p.second).instantiate(params);
+  }else {
     if(Debug.isOn("sort")) {
       Debug("sort") << "instantiating using a sort substitution" << endl;
       Debug("sort") << "have formals [";
@@ -167,6 +170,11 @@ Type DeclarationScope::lookupType(const std::string& name,
   }
 }
 
+size_t DeclarationScope::lookupArity( const std::string& name ){
+  pair<vector<Type>, Type> p = (*d_typeMap->find(name)).second;
+  return p.first.size();
+}
+
 void DeclarationScope::popScope() throw(ScopeException) {
   if( d_context->getLevel() == 0 ) {
     throw ScopeException();
index 699dca6fa6f4ed2f06e5f4b16716a9baf994c165..4cdb71ddcc5d3ec04a0d94b5e30e626628896fdd 100644 (file)
@@ -174,6 +174,11 @@ public:
   Type lookupType(const std::string& name,
                   const std::vector<Type>& params) const throw(AssertionException);
 
+  /**
+   * Lookup the arity of a bound parameterized type.
+   */
+  size_t lookupArity( const std::string& name );
+
   /**
    * Pop a scope level. Deletes all bindings since the last call to
    * <code>pushScope</code>. Calls to <code>pushScope</code> and
index f0c90ebdb042a9bfc27b9184ac67f32ab537b488..c32dbbc7d14219b5dfcd3e23c3f9a728c757b2b3 100644 (file)
@@ -486,12 +486,12 @@ DatatypeType ExprManager::mkDatatypeType(const Datatype& datatype) {
 
 std::vector<DatatypeType>
 ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes) {
-  return mkMutualDatatypeTypes(datatypes, set<SortType>());
+  return mkMutualDatatypeTypes(datatypes, set<Type>());
 }
 
 std::vector<DatatypeType>
 ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
-                                   const std::set<SortType>& unresolvedTypes) {
+                                   const std::set<Type>& unresolvedTypes) {
   NodeManagerScope nms(d_nodeManager);
   std::map<std::string, DatatypeType> nameResolutions;
   std::vector<DatatypeType> dtts;
@@ -505,7 +505,18 @@ ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
         i_end = datatypes.end();
       i != i_end;
       ++i) {
-    TypeNode* typeNode = new TypeNode(d_nodeManager->mkTypeConst(*i));
+    TypeNode* typeNode;
+    if( (*i).getNumParameters()==0 ){
+      typeNode = new TypeNode(d_nodeManager->mkTypeConst(*i));
+    }else{
+      TypeNode cons = d_nodeManager->mkTypeConst(*i);
+      std::vector< TypeNode > params;
+      params.push_back( cons );
+      for( unsigned int ip=0; ip<(*i).getNumParameters(); ip++ ){
+        params.push_back( TypeNode::fromType( (*i).getParameter( ip ) ) );
+      }
+      typeNode = new TypeNode( d_nodeManager->mkTypeNode( kind::PARAMETRIC_DATATYPE, params ) );
+    }
     Type type(d_nodeManager, typeNode);
     DatatypeType dtt(type);
     CheckArgument(nameResolutions.find((*i).getName()) == nameResolutions.end(),
@@ -526,13 +537,21 @@ ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
   //
   // @TODO get rid of named resolutions altogether and handle
   // everything with these resolutions?
+  std::vector< SortConstructorType > paramTypes;
+  std::vector< DatatypeType > paramReplacements;
   std::vector<Type> placeholders;// to hold the "unresolved placeholders"
   std::vector<Type> replacements;// to hold our final, resolved types
-  for(std::set<SortType>::const_iterator i = unresolvedTypes.begin(),
+  for(std::set<Type>::const_iterator i = unresolvedTypes.begin(),
         i_end = unresolvedTypes.end();
       i != i_end;
       ++i) {
-    std::string name = (*i).getName();
+    std::string name;
+    if( (*i).isSort() ){
+      name = SortType(*i).getName();
+    }else{
+      Assert( (*i).isSortConstructor() );
+      name = SortConstructorType(*i).getName();
+    }
     std::map<std::string, DatatypeType>::const_iterator resolver =
       nameResolutions.find(name);
     CheckArgument(resolver != nameResolutions.end(),
@@ -543,8 +562,14 @@ ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
     // unresolved SortType used as a placeholder in complex types)
     // with "(*resolver).second" (the DatatypeType we created in the
     // first step, above).
-    placeholders.push_back(*i);
-    replacements.push_back((*resolver).second);
+    if( (*i).isSort() ){
+      placeholders.push_back(*i);
+      replacements.push_back( (*resolver).second );
+    }else{
+      Assert( (*i).isSortConstructor() );
+      paramTypes.push_back( SortConstructorType(*i) );
+      paramReplacements.push_back( (*resolver).second );
+    }
   }
 
   // Lastly, perform the final resolutions and checks.
@@ -555,7 +580,8 @@ ExprManager::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
     const Datatype& dt = (*i).getDatatype();
     if(!dt.isResolved()) {
       const_cast<Datatype&>(dt).resolve(this, nameResolutions,
-                                        placeholders, replacements);
+                                        placeholders, replacements,
+                                        paramTypes, paramReplacements);
     }
 
     // Now run some checks, including a check to make sure that no
index f395d781c457ef075fa71274e3b93cc892ee2980..712273473f7f79ce7a6aabcea0d0b7ddf210dbda 100644 (file)
@@ -357,7 +357,7 @@ public:
    */
   std::vector<DatatypeType>
   mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes,
-                        const std::set<SortType>& unresolvedTypes);
+                        const std::set<Type>& unresolvedTypes);
 
   /**
    * Make a type representing a constructor with the given parameterization.
index 9974df6ca32db092d75ce18245eae51c2b5db5de..8b803e696251f4d5145e74548d88ae9a99b51edc 100644 (file)
@@ -53,10 +53,12 @@ namespace expr {
 namespace attr {
   struct VarNameTag {};
   struct SortArityTag {};
+  struct DatatypeTag {};
 }/* CVC4::expr::attr namespace */
 
 typedef Attribute<attr::VarNameTag, std::string> VarNameAttr;
 typedef Attribute<attr::SortArityTag, uint64_t> SortArityAttr;
+typedef Attribute<attr::SortArityTag, void*> DatatypeAttr;
 
 }/* CVC4::expr namespace */
 
@@ -1188,7 +1190,6 @@ inline TypeNode NodeManager::mkTypeNode(Kind kind,
   return NodeBuilder<>(this, kind).append(children).constructTypeNode();
 }
 
-
 inline Node NodeManager::mkVar(const std::string& name, const TypeNode& type) {
   Node n = mkVar(type);
   n.setAttribute(TypeAttr(), type);
index 567bb2d40af7d2a3a5c54947206469ba3a7de129..2bcdcedfa513754f65a582661b79001d2ca9bb16 100644 (file)
@@ -225,7 +225,7 @@ Type::operator DatatypeType() const throw(AssertionException) {
 /** Is this the Datatype type? */
 bool Type::isDatatype() const {
   NodeManagerScope nms(d_nodeManager);
-  return d_typeNode->isDatatype();
+  return d_typeNode->isDatatype() || d_typeNode->isParametricDatatype();
 }
 
 /** Cast to a Constructor type */
@@ -388,6 +388,18 @@ string SortType::getName() const {
   return d_typeNode->getAttribute(expr::VarNameAttr());
 }
 
+bool SortType::isParameterized() const
+{
+  return false;
+}
+
+/** Get the parameter types */
+std::vector<Type> SortType::getParamTypes() const
+{
+  vector<Type> params;
+  return params;
+}
+
 string SortConstructorType::getName() const {
   NodeManagerScope nms(d_nodeManager);
   return d_typeNode->getAttribute(expr::VarNameAttr());
@@ -514,7 +526,48 @@ std::vector<Type> ConstructorType::getArgTypes() const {
 }
 
 const Datatype& DatatypeType::getDatatype() const {
-  return d_typeNode->getConst<Datatype>();
+  if( d_typeNode->isParametricDatatype() ){
+    Assert( (*d_typeNode)[0].getKind()==kind::DATATYPE_TYPE );
+    const Datatype& dt = (*d_typeNode)[0].getConst<Datatype>();
+    return dt;
+  }else{
+    return d_typeNode->getConst<Datatype>();
+  }
+}
+
+bool DatatypeType::isParametric() const {
+  return d_typeNode->isParametricDatatype();
+}
+
+size_t DatatypeType::getArity() const {
+  NodeManagerScope nms(d_nodeManager);
+  return d_typeNode->getNumChildren() - 1;
+}
+
+std::vector<Type> DatatypeType::getParamTypes() const{
+  NodeManagerScope nms(d_nodeManager);
+  vector<Type> params;
+  vector<TypeNode> paramNodes = d_typeNode->getParamTypes();
+  vector<TypeNode>::iterator it = paramNodes.begin();
+  vector<TypeNode>::iterator it_end = paramNodes.end();
+  for(; it != it_end; ++ it) {
+    params.push_back(makeType(*it));
+  }
+  return params;
+}
+
+DatatypeType DatatypeType::instantiate(const std::vector<Type>& params) const {
+  NodeManagerScope nms(d_nodeManager);
+  TypeNode cons = d_nodeManager->mkTypeConst( getDatatype() );
+  vector<TypeNode> paramsNodes;
+  paramsNodes.push_back( cons );
+  for(vector<Type>::const_iterator i = params.begin(),
+        iend = params.end();
+      i != iend;
+      ++i) {
+    paramsNodes.push_back(*getTypeNode(*i));
+  }
+  return DatatypeType(makeType(d_nodeManager->mkTypeNode(kind::PARAMETRIC_DATATYPE,paramsNodes)));
 }
 
 DatatypeType SelectorType::getDomain() const {
index 980a750d5c8e11f56ce0fe906f9f46e09e25c425..096336b0c5ed0e4804d27dfad795fd9c37c3381a 100644 (file)
@@ -469,6 +469,12 @@ public:
 
   /** Get the name of the sort */
   std::string getName() const;
+
+  /** is parameterized */
+  bool isParameterized() const;
+
+  /** Get the parameter types */
+  std::vector<Type> getParamTypes() const;
 };/* class SortType */
 
 /**
@@ -533,8 +539,19 @@ public:
   /** Get the underlying datatype */
   const Datatype& getDatatype() const;
 
-};/* class DatatypeType */
+  /** is parameterized */
+  bool isParametric() const;
+
+  /** Get the parameter types */
+  std::vector<Type> getParamTypes() const;
 
+  /** Get the arity of the datatype constructor */
+  size_t getArity() const;
+
+  /** Instantiate a datatype using this datatype constructor */
+  DatatypeType instantiate(const std::vector<Type>& params) const;
+
+};/* class DatatypeType */
 
 /**
  * Class encapsulating the constructor type
index a6ca390156d7e80708831cb074b2a7bd7900c498..9283da13abcb5e82d3f649dbb1757d2dc306e2c9 100644 (file)
@@ -136,6 +136,15 @@ std::vector<TypeNode> TypeNode::getArgTypes() const {
   return args;
 }
 
+std::vector<TypeNode> TypeNode::getParamTypes() const {
+  vector<TypeNode> params;
+  Assert( isParametricDatatype() );
+  for(unsigned i = 1, i_end = getNumChildren(); i < i_end; ++i) {
+    params.push_back((*this)[i]);
+  }
+  return params;
+}
+
 TypeNode TypeNode::getRangeType() const {
   if(isTester()) {
     return NodeManager::currentNM()->booleanType();
@@ -185,6 +194,11 @@ bool TypeNode::isDatatype() const {
   return getKind() == kind::DATATYPE_TYPE;
 }
 
+/** Is this a datatype type */
+bool TypeNode::isParametricDatatype() const {
+  return getKind() == kind::PARAMETRIC_DATATYPE;
+}
+
 /** Is this a constructor type */
 bool TypeNode::isConstructor() const {
   return getKind() == kind::CONSTRUCTOR_TYPE;
index 7f6ebd47105d5653a72218ec1b903dd0b07059a3..d6c685a7549acc51c459f33bc3544bfa10207e5a 100644 (file)
@@ -451,6 +451,11 @@ public:
    */
   std::vector<TypeNode> getArgTypes() const;
 
+  /**
+   * Get the paramater types of a parameterized datatype.
+   */
+  std::vector<TypeNode> getParamTypes() const;
+
   /**
    * Get the range type (i.e., the type of the result) of a function,
    * datatype constructor, datatype selector, or datatype tester.
@@ -479,6 +484,9 @@ public:
   /** Is this a datatype type */
   bool isDatatype() const;
 
+  /** Is this a parameterized datatype type */
+  bool isParametricDatatype() const;
+
   /** Is this a constructor type */
   bool isConstructor() const;
 
index b3c253dab5581359eea1474c5ff056b3c27eae7c..3c8d6e1ce838415c7ac45781bfaaecaacf621aff 100644 (file)
@@ -1002,12 +1002,27 @@ restrictedTypePossiblyFunctionLHS[CVC4::Type& t,
 }
     /* named types */
   : identifier[id,check,SYM_SORT]
-    parameterization[check]?
-    { if(check == CHECK_DECLARED ||
+    parameterization[check,types]?
+    { 
+      if(check == CHECK_DECLARED ||
          PARSER_STATE->isDeclared(id, SYM_SORT)) {
-        t = PARSER_STATE->getSort(id);
+        Debug("parser-param") << "param: getSort " << id << " " << types.size() << " " << PARSER_STATE->getArity( id ) 
+                              << " " << PARSER_STATE->isDeclared(id, SYM_SORT) << std::endl;
+        if( types.size()>0 ){
+          t = PARSER_STATE->getSort(id, types);
+        }else{
+          t = PARSER_STATE->getSort(id);
+        }
       } else {
-        t = PARSER_STATE->mkUnresolvedType(id);
+        if( types.empty() ){
+          t = PARSER_STATE->mkUnresolvedType(id);
+          Debug("parser-param") << "param: make unres type " << id << std::endl;
+        }else{
+          t = PARSER_STATE->mkUnresolvedTypeConstructor(id,types);
+          t = SortConstructorType(t).instantiate( types );
+          Debug("parser-param") << "param: make unres param type " << id << " " << types.size() << " " 
+                                << PARSER_STATE->getArity( id ) << std::endl;
+        }
       }
     }
 
@@ -1076,12 +1091,13 @@ restrictedTypePossiblyFunctionLHS[CVC4::Type& t,
     }
   ;
 
-parameterization[CVC4::parser::DeclarationCheck check]
+parameterization[CVC4::parser::DeclarationCheck check, 
+                 std::vector<CVC4::Type>& params]
 @init {
   Type t;
 }
-  : LBRACKET restrictedType[t,check] ( COMMA restrictedType[t,check] )* RBRACKET
-    { UNSUPPORTED("parameterized types not yet supported"); }
+  : LBRACKET restrictedType[t,check] { Debug("parser-param") << "t = " << t << std::endl; params.push_back( t ); }
+    ( COMMA restrictedType[t,check] { Debug("parser-param") << "t = " << t << std::endl; params.push_back( t ); } )* RBRACKET
   ;
 
 bound returns [CVC4::parser::cvc::mySubrangeBound bound]
@@ -1267,6 +1283,7 @@ term[CVC4::Expr& f]
   std::vector<CVC4::Expr> expressions;
   std::vector<unsigned> operators;
   unsigned op;
+  Type t;
 }
   : storeTerm[f] { expressions.push_back(f); }
     ( arithmeticBinop[op] storeTerm[f] { operators.push_back(op); expressions.push_back(f); } )*
@@ -1374,6 +1391,7 @@ postfixTerm[CVC4::Expr& f]
   bool extract = false, left = false;
   std::vector<Expr> args;
   std::string id;
+  Type t;
 }
   : bvTerm[f]
     ( /* array select / bitvector extract */
@@ -1443,7 +1461,9 @@ postfixTerm[CVC4::Expr& f]
           f = EXPR_MANAGER->mkVar(TupleType(f.getType()).getTypes()[k]); }
       )*/
     )*
-    typeAscription[f]?
+    (typeAscription[f, t] 
+     { //f = MK_EXPR(CVC4::kind::APPLY_TYPE_ANNOTATION, MK_CONST(t), f); 
+     } )? 
   ;
 
 bvTerm[CVC4::Expr& f]
@@ -1642,20 +1662,19 @@ simpleTerm[CVC4::Expr& f]
  * Matches (and performs) a type ascription.
  * The f arg is the term to check (it is an input-only argument).
  */
-typeAscription[const CVC4::Expr& f]
+typeAscription[const CVC4::Expr& f, CVC4::Type& t]
 @init {
-  Type t;
 }
   : COLON COLON type[t,CHECK_DECLARED]
-    { if(f.getType() != t) {
-        std::stringstream ss;
-        ss << Expr::setlanguage(language::output::LANG_CVC4)
-           << "type ascription not satisfied\n"
-           << "term:     " << f << '\n'
-           << "has type: " << f.getType() << '\n'
-           << "ascribed: " << t;
-        PARSER_STATE->parseError(ss.str());
-      }
+    { //if(f.getType() != t) {
+      //  std::stringstream ss;
+      //  ss << Expr::setlanguage(language::output::LANG_CVC4)
+      //     << "type ascription not satisfied\n"
+      //     << "term:     " << f << '\n'
+      //     << "has type: " << f.getType() << '\n'
+      //     << "ascribed: " << t;
+      //  PARSER_STATE->parseError(ss.str());
+      //}
     }
   ;
 
@@ -1706,16 +1725,23 @@ iteElseTerm[CVC4::Expr& f]
 datatypeDef[std::vector<CVC4::Datatype>& datatypes]
 @init {
   std::string id, id2;
+  Type t;
+  std::vector< Type > params;
 }
     /* This really needs to be CHECK_NONE, or mutually-recursive datatypes
      * won't work, because this type will already be "defined" as an
      * unresolved type; don't worry, we check below. */
-  : identifier[id,CHECK_NONE,SYM_SORT]
-    ( LBRACKET identifier[id2,CHECK_NONE,SYM_SORT]
-      ( COMMA identifier[id2,CHECK_NONE,SYM_SORT] )* RBRACKET
-      { UNSUPPORTED("parameterized datatypes not yet supported"); }
+  : identifier[id,CHECK_NONE,SYM_SORT] { PARSER_STATE->pushScope(); }
+    ( LBRACKET identifier[id2,CHECK_UNDECLARED,SYM_SORT] { 
+        t = PARSER_STATE->mkSort(id2);
+        params.push_back( t ); 
+      }
+      ( COMMA identifier[id2,CHECK_UNDECLARED,SYM_SORT] { 
+        t = PARSER_STATE->mkSort(id2);
+        params.push_back( t ); } 
+      )* RBRACKET
     )?
-    { datatypes.push_back(Datatype(id));
+    { datatypes.push_back(Datatype(id,params));
       if(!PARSER_STATE->isUnresolvedType(id)) {
         // if not unresolved, must be undeclared
         PARSER_STATE->checkDeclaration(id, CHECK_UNDECLARED, SYM_SORT);
@@ -1723,6 +1749,7 @@ datatypeDef[std::vector<CVC4::Datatype>& datatypes]
     }
     EQUAL_TOK constructorDef[datatypes.back()]
     ( BAR constructorDef[datatypes.back()] )*
+    { PARSER_STATE->popScope(); }
   ;
 
 /**
index 29ade43c1b63059a81864f6baec626f94c408671..efe01fb40adfecc8b2637c30482fae3095cca84a 100644 (file)
@@ -95,6 +95,11 @@ Type Parser::getSort(const std::string& name,
   return t;
 }
 
+size_t Parser::getArity(const std::string& sort_name){
+  Assert( isDeclared(sort_name, SYM_SORT) );
+  return d_declScope->lookupArity(sort_name);
+}
+
 /* Returns true if name is bound to a boolean variable. */
 bool Parser::isBoolean(const std::string& name) {
   return isDeclared(name, SYM_VARIABLE) && getType(name).isBoolean();
@@ -237,6 +242,24 @@ SortType Parser::mkUnresolvedType(const std::string& name) {
   return unresolved;
 }
 
+SortConstructorType Parser::mkUnresolvedTypeConstructor(const std::string& name, 
+                                                        size_t arity)
+{
+  SortConstructorType unresolved = mkSortConstructor(name,arity);
+  d_unresolved.insert(unresolved);
+  return unresolved;
+}
+SortConstructorType Parser::mkUnresolvedTypeConstructor(const std::string& name, 
+                                                        const std::vector<Type>& params){
+  Debug("parser") << "newSortConstructor(P)(" << name << ", " << params.size() << ")"
+                  << std::endl;
+  SortConstructorType unresolved = d_exprManager->mkSortConstructor(name, params.size());
+  defineType(name, params, unresolved);
+  Type t = getSort( name, params );
+  d_unresolved.insert(unresolved);
+  return unresolved;
+}
+
 bool Parser::isUnresolvedType(const std::string& name) {
   if(!isDeclared(name, SYM_SORT)) {
     return false;
@@ -260,7 +283,12 @@ Parser::mkMutualDatatypeTypes(const std::vector<Datatype>& datatypes) {
     if(isDeclared(name, SYM_SORT)) {
       throw ParserException(name + " already declared");
     }
-    defineType(name, t);
+    if( t.isParametric() ){
+      std::vector< Type > paramTypes = t.getParamTypes();
+      defineType(name, paramTypes, t );
+    }else{
+      defineType(name, t);
+    }
     for(Datatype::const_iterator j = dt.begin(),
           j_end = dt.end();
         j != j_end;
index 6509b192be0f0b20afd5442ef57909d2b8d18da5..6d55e195ee8481f568524690a67f930b1cc7ea30 100644 (file)
@@ -150,7 +150,7 @@ class CVC4_PUBLIC Parser {
    * depend on mkMutualDatatypeTypes() to check everything and clear
    * this out.
    */
-  std::set<SortType> d_unresolved;
+  std::set<Type> d_unresolved;
 
   /**
    * "Preemption commands": extra commands implied by subterms that
@@ -254,6 +254,11 @@ public:
   Type getSort(const std::string& sort_name,
                const std::vector<Type>& params);
 
+  /**
+   * Returns arity of a (parameterized) sort, given a name and args.
+   */
+  size_t getArity(const std::string& sort_name);
+
   /**
    * Checks if a symbol has been declared.
    * @param name the symbol name
@@ -367,6 +372,14 @@ public:
    */
   SortType mkUnresolvedType(const std::string& name);
 
+  /**
+   * Creates a new "unresolved type," used only during parsing.
+   */
+  SortConstructorType mkUnresolvedTypeConstructor(const std::string& name, 
+                                                  size_t arity);
+  SortConstructorType mkUnresolvedTypeConstructor(const std::string& name, 
+                                                  const std::vector<Type>& params);
+
   /**
    * Returns true IFF name is an unresolved type.
    */
index eea15d6b621312b029c4c44c97445ed2991aad99..df8cb6eb8df81a3978f5af47e3af5c90c919caad 100644 (file)
@@ -43,7 +43,8 @@ public:
         return RewriteResponse(REWRITE_DONE,
                                NodeManager::currentNM()->mkConst(result));
       } else {
-        const Datatype& dt = in[0].getType().getConst<Datatype>();
+        //const Datatype& dt = in[0].getType().getConst<Datatype>();
+        const Datatype& dt = DatatypeType(in[0].getType().toType()).getDatatype();
         if(dt.getNumConstructors() == 1) {
           // only one constructor, so it must be
           Debug("datatypes-rewrite") << "DatatypesRewriter::postRewrite: "
index 37a65a2b0ba705dd92f5a87e88eddc607971a609..d08e3875cf400167958841049196528df24caf17 100644 (file)
@@ -53,4 +53,13 @@ well-founded DATATYPE_TYPE \
     "%TYPE%.getConst<Datatype>().mkGroundTerm()" \
     "util/datatype.h"
 
+operator PARAMETRIC_DATATYPE 1: "parametric datatype"
+cardinality PARAMETRIC_DATATYPE \
+    "DatatypeType(%TYPE%.toType()).getDatatype().getCardinality()" \
+    "util/datatype.h"
+well-founded PARAMETRIC_DATATYPE\
+    "DatatypeType(%TYPE%.toType()).getDatatype().isWellFounded()" \
+    "DatatypeType(%TYPE%.toType()).getDatatype().mkGroundTerm()" \
+    "util/datatype.h"
+    
 endtheory
index 6808ef749d00277fc52fc2c2178dcdab8edd84fa..2f0b82f7c48f751d1f51b536457a00381c9f6f91 100644 (file)
@@ -851,7 +851,8 @@ void TheoryDatatypes::addTermToLabels( Node t ) {
         const Datatype& dt = ((DatatypeType)(t.getType()).toType()).getDatatype();
         if( dt.getNumConstructors()==1 ){
           Node tester = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[0].getTester() ), t );
-          addTester( tester );
+          lbl->push_back( tester );
+          d_checkMap[ t ] = true;
           d_em.addNodeAxiom( tester, Reason::idt_texhaust );
         }
         d_labels.insertDataFromContextMemory(tmp, lbl);
index bc1581f140d1b6ccfaee887a042b8b3efe18c120..dc2e95f9d5be9dc826c99e23be36d79de51552de 100644 (file)
@@ -32,6 +32,70 @@ namespace expr {
 namespace theory {
 namespace datatypes {
 
+class Matcher
+{
+private:
+  std::vector< TypeNode > d_types;
+  std::vector< TypeNode > d_match;
+public:
+  Matcher(){}
+  Matcher( DatatypeType dt ){
+    std::vector< Type > argTypes = dt.getParamTypes();
+    addTypes( argTypes );
+  }
+  ~Matcher(){}
+
+  void addType( Type t ){
+    d_types.push_back( TypeNode::fromType( t ) );
+    d_match.push_back( TypeNode::null() );
+  }
+  void addTypes( std::vector< Type > types ){
+    for( int i=0; i<(int)types.size(); i++ ){
+      addType( types[i] );
+    }
+  }
+
+  bool doMatching( TypeNode base, TypeNode match ){
+    std::vector< TypeNode >::iterator i = std::find( d_types.begin(), d_types.end(), base );
+    if( i!=d_types.end() ){
+      int index = i - d_types.begin();
+      if( !d_match[index].isNull() && d_match[index]!=match ){
+        return false;
+      }else{
+        d_match[ i - d_types.begin() ] = match;
+        return true;
+      }
+    }else if( base==match ){
+      return true;
+    }else if( base.getKind()!=match.getKind() || base.getNumChildren()!=match.getNumChildren() ){
+      return false;
+    }else{
+      for( int i=0; i<(int)base.getNumChildren(); i++ ){
+        if( !doMatching( base[i], match[i] ) ){
+          return false;
+        }
+      }
+      return true;
+    }
+  }
+
+  TypeNode getMatch( unsigned int i ){ return d_match[i]; }
+  void getTypes( std::vector<Type>& types ) { 
+    types.clear();
+    for( int i=0; i<(int)d_match.size(); i++ ){
+      types.push_back( d_types[i].toType() );
+    }
+  }
+  void getMatches( std::vector<Type>& types ) { 
+    types.clear();
+    for( int i=0; i<(int)d_match.size(); i++ ){
+      Assert( !d_match[i].isNull() ); //verify that all types have been set
+      types.push_back( d_match[i].toType() );
+    }
+  }
+};
+
+
 typedef expr::Attribute<expr::attr::DatatypeConstructorTypeGroundTermTag, Node> GroundTermAttr;
 
 struct DatatypeConstructorTypeRule {
@@ -39,24 +103,43 @@ struct DatatypeConstructorTypeRule {
     throw(TypeCheckingExceptionPrivate) {
     Assert(n.getKind() == kind::APPLY_CONSTRUCTOR);
     TypeNode consType = n.getOperator().getType(check);
-    if(check) {
-      Debug("typecheck-idt") << "typecheck cons: " << n << " " << n.getNumChildren() << std::endl;
-      Debug("typecheck-idt") << "cons type: " << consType << " " << consType.getNumChildren() << std::endl;
-      if(n.getNumChildren() != consType.getNumChildren() - 1) {
-        throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the constructor type");
-      }
-      TNode::iterator child_it = n.begin();
-      TNode::iterator child_it_end = n.end();
-      TypeNode::iterator tchild_it = consType.begin();
+    Type t = consType.getConstructorRangeType().toType();
+    Assert( t.isDatatype() );
+    DatatypeType dt = DatatypeType(t);
+    TNode::iterator child_it = n.begin();
+    TNode::iterator child_it_end = n.end();
+    TypeNode::iterator tchild_it = consType.begin();
+    if( ( dt.isParametric() || check ) && n.getNumChildren() != consType.getNumChildren() - 1 ){
+      throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the constructor type");
+    }
+    if( dt.isParametric() ){
+      Debug("typecheck-idt") << "typecheck parameterized datatype " << n << std::endl;
+      Matcher m( dt );
       for(; child_it != child_it_end; ++child_it, ++tchild_it) {
         TypeNode childType = (*child_it).getType(check);
-        Debug("typecheck-idt") << "typecheck cons arg: " << childType << " " << (*tchild_it) << std::endl;
-        if(childType != *tchild_it) {
-          throw TypeCheckingExceptionPrivate(n, "bad type for constructor argument");
+        if( !m.doMatching( *tchild_it, childType ) ){
+          throw TypeCheckingExceptionPrivate(n, "matching failed for parameterized constructor");
+        }
+      }
+      std::vector< Type > instTypes;
+      m.getMatches( instTypes );
+      TypeNode range = TypeNode::fromType( dt.instantiate( instTypes ) );
+      Debug("typecheck-idt") << "Return " << range << std::endl;
+      return range;
+    }else{
+      if(check) {
+        Debug("typecheck-idt") << "typecheck cons: " << n << " " << n.getNumChildren() << std::endl;
+        Debug("typecheck-idt") << "cons type: " << consType << " " << consType.getNumChildren() << std::endl;
+        for(; child_it != child_it_end; ++child_it, ++tchild_it) {
+          TypeNode childType = (*child_it).getType(check);
+          Debug("typecheck-idt") << "typecheck cons arg: " << childType << " " << (*tchild_it) << std::endl;
+          if(childType != *tchild_it) {
+            throw TypeCheckingExceptionPrivate(n, "bad type for constructor argument");
+          }
         }
       }
+      return consType.getConstructorRangeType();
     }
-    return consType.getConstructorRangeType();
   }
 };/* struct DatatypeConstructorTypeRule */
 
@@ -65,18 +148,38 @@ struct DatatypeSelectorTypeRule {
     throw(TypeCheckingExceptionPrivate) {
     Assert(n.getKind() == kind::APPLY_SELECTOR);
     TypeNode selType = n.getOperator().getType(check);
-    Debug("typecheck-idt") << "typecheck sel: " << n << std::endl;
-    Debug("typecheck-idt") << "sel type: " << selType << std::endl;
-    if(check) {
-      if(n.getNumChildren() != 1) {
-        throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the selector type");
-      }
+    Type t = selType[0].toType();
+    Assert( t.isDatatype() );
+    DatatypeType dt = DatatypeType(t);
+    if( ( dt.isParametric() || check ) && n.getNumChildren() != 1 ){
+      throw TypeCheckingExceptionPrivate(n, "number of arguments does not match the selector type");
+    }
+    if( dt.isParametric() ){
+      Debug("typecheck-idt") << "typecheck parameterized sel: " << n << std::endl;
+      Matcher m( dt );
       TypeNode childType = n[0].getType(check);
-      if(selType[0] != childType) {
-        throw TypeCheckingExceptionPrivate(n, "bad type for selector argument");
+      if( !m.doMatching( selType[0], childType ) ){
+        throw TypeCheckingExceptionPrivate(n, "matching failed for selector argument of parameterized datatype");
       }
+      std::vector< Type > types, matches;
+      m.getTypes( types );
+      m.getMatches( matches );
+      Type range = selType[1].toType();
+      range = range.substitute( types, matches );
+      Debug("typecheck-idt") << "Return " << range << std::endl;
+      return TypeNode::fromType( range );
+    }else{
+      if(check) {
+        Debug("typecheck-idt") << "typecheck sel: " << n << std::endl;
+        Debug("typecheck-idt") << "sel type: " << selType << std::endl;
+        TypeNode childType = n[0].getType(check);
+        if(selType[0] != childType) {
+          Debug("typecheck-idt") << "ERROR: " << selType[0].getKind() << " " << childType.getKind() << std::endl;
+          throw TypeCheckingExceptionPrivate(n, "bad type for selector argument");
+        }
+      }
+      return selType[1];
     }
-    return selType[1];
   }
 };/* struct DatatypeSelectorTypeRule */
 
@@ -90,10 +193,21 @@ struct DatatypeTesterTypeRule {
       }
       TypeNode testType = n.getOperator().getType(check);
       TypeNode childType = n[0].getType(check);
-      Debug("typecheck-idt") << "typecheck test: " << n << std::endl;
-      Debug("typecheck-idt") << "test type: " << testType << std::endl;
-      if(testType[0] != childType) {
-        throw TypeCheckingExceptionPrivate(n, "bad type for tester argument");
+      Type t = testType[0].toType();
+      Assert( t.isDatatype() );
+      DatatypeType dt = DatatypeType(t);
+      if( dt.isParametric() ){
+        Debug("typecheck-idt") << "typecheck parameterized tester: " << n << std::endl;
+        Matcher m( dt );
+        if( !m.doMatching( testType[0], childType ) ){
+          throw TypeCheckingExceptionPrivate(n, "matching failed for tester argument of parameterized datatype");
+        }
+      }else{
+        Debug("typecheck-idt") << "typecheck test: " << n << std::endl;
+        Debug("typecheck-idt") << "test type: " << testType << std::endl;
+        if(testType[0] != childType) {
+          throw TypeCheckingExceptionPrivate(n, "bad type for tester argument");
+        }
       }
     }
     return nodeManager->booleanType();
@@ -140,7 +254,8 @@ struct ConstructorProperties {
     // Constructors within the same Datatype could share the same
     // type.  So we scan through the datatype to find one that
     // matches.
-    const Datatype& dt = type[type.getNumChildren() - 1].getConst<Datatype>();
+    //const Datatype& dt = type[type.getNumChildren() - 1].getConst<Datatype>();
+    const Datatype& dt = DatatypeType(type[type.getNumChildren() - 1].toType()).getDatatype();
     for(Datatype::const_iterator i = dt.begin(),
           i_end = dt.end();
         i != i_end;
index ab52e7f937a90a94f52253e0016970d22e73b59f..ecb0896580a81181e1c5b8213439a5de80b38b60 100644 (file)
@@ -54,13 +54,15 @@ const Datatype& Datatype::datatypeOf(Expr item) {
   TypeNode t = Node::fromExpr(item).getType();
   switch(t.getKind()) {
   case kind::CONSTRUCTOR_TYPE:
-    return t[t.getNumChildren() - 1].getConst<Datatype>();
+    //return t[t.getNumChildren() - 1].getConst<Datatype>();
+    return DatatypeType(t[t.getNumChildren() - 1].toType()).getDatatype();
   case kind::SELECTOR_TYPE:
   case kind::TESTER_TYPE:
-    return t[0].getConst<Datatype>();
+    //return t[0].getConst<Datatype>();
+    return DatatypeType(t[0].toType()).getDatatype();
   default:
     Unhandled("arg must be a datatype constructor, selector, or tester");
-  }
+  } 
 }
 
 size_t Datatype::indexOf(Expr item) {
@@ -77,9 +79,12 @@ size_t Datatype::indexOf(Expr item) {
 void Datatype::resolve(ExprManager* em,
                        const std::map<std::string, DatatypeType>& resolutions,
                        const std::vector<Type>& placeholders,
-                       const std::vector<Type>& replacements)
+                       const std::vector<Type>& replacements,
+                       const std::vector< SortConstructorType >& paramTypes,
+                       const std::vector< DatatypeType >& paramReplacements)
   throw(AssertionException, DatatypeResolutionException) {
 
+  //cout << "resolve " << *this << "..." << std::endl;
   AssertArgument(em != NULL, "cannot resolve a Datatype with a NULL expression manager");
   CheckArgument(!d_resolved, "cannot resolve a Datatype twice");
   AssertArgument(resolutions.find(d_name) != resolutions.end(),
@@ -92,13 +97,15 @@ void Datatype::resolve(ExprManager* em,
   d_resolved = true;
   size_t index = 0;
   for(iterator i = begin(), i_end = end(); i != i_end; ++i) {
-    (*i).resolve(em, self, resolutions, placeholders, replacements);
+    (*i).resolve(em, self, resolutions, placeholders, replacements, paramTypes, paramReplacements);
     Assert((*i).isResolved());
     Node::fromExpr((*i).d_constructor).setAttribute(DatatypeIndexAttr(), index);
     Node::fromExpr((*i).d_tester).setAttribute(DatatypeIndexAttr(), index++);
   }
   d_self = self;
   Assert(index == getNumConstructors());
+
+  //cout << "done resolve " << *this << std::endl;
 }
 
 void Datatype::addConstructor(const Constructor& c) {
@@ -263,10 +270,16 @@ Expr Datatype::mkGroundTerm() const throw(AssertionException) {
 
 DatatypeType Datatype::getDatatypeType() const throw(AssertionException) {
   CheckArgument(isResolved(), *this, "Datatype must be resolved to get its DatatypeType");
-  Assert(!d_self.isNull());
+  Assert(!d_self.isNull() && !DatatypeType(d_self).isParametric());
   return DatatypeType(d_self);
 }
 
+DatatypeType Datatype::getDatatypeType(const std::vector<Type>& params) const throw(AssertionException) {
+  CheckArgument(isResolved(), *this, "Datatype must be resolved to get its DatatypeType");
+  Assert(!d_self.isNull() && DatatypeType(d_self).isParametric());
+  return DatatypeType(d_self).instantiate(params);
+}
+
 bool Datatype::operator==(const Datatype& other) const throw() {
   // two datatypes are == iff the name is the same and they have
   // exactly matching constructors (in the same order)
@@ -349,8 +362,13 @@ const Datatype::Constructor& Datatype::operator[](size_t index) const {
 void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self,
                                     const std::map<std::string, DatatypeType>& resolutions,
                                     const std::vector<Type>& placeholders,
-                                    const std::vector<Type>& replacements)
+                                    const std::vector<Type>& replacements,
+                                    const std::vector< SortConstructorType >& paramTypes,
+                                    const std::vector< DatatypeType >& paramReplacements)
   throw(AssertionException, DatatypeResolutionException) {
+
+  //cout << "resolve " << *this << "..." << std::endl;
+
   AssertArgument(em != NULL, "cannot resolve a Datatype with a NULL expression manager");
   CheckArgument(!isResolved(),
                 "cannot resolve a Datatype constructor twice; "
@@ -383,6 +401,9 @@ void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self,
       if(!placeholders.empty()) {
         range = range.substitute(placeholders, replacements);
       }
+      if(!paramTypes.empty() ){
+        range = doParametricSubstitution( range, paramTypes, paramReplacements );
+      }
       (*i).d_selector = em->mkVar((*i).d_name, em->mkSelectorType(self, range));
     }
     Node::fromExpr((*i).d_selector).setAttribute(DatatypeIndexAttr(), index++);
@@ -403,6 +424,37 @@ void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self,
   for(iterator i = begin(), i_end = end(); i != i_end; ++i) {
     (*i).d_constructor = d_constructor;
   }
+
+  //cout << "done resolve " << *this << std::endl;
+}
+
+Type Datatype::Constructor::doParametricSubstitution( Type range, 
+                                  const std::vector< SortConstructorType >& paramTypes, 
+                                  const std::vector< DatatypeType >& paramReplacements ){
+  TypeNode typn = TypeNode::fromType( range );
+  if(typn.getNumChildren() == 0) {
+    return range;
+  } else {
+    std::vector< Type > origChildren;
+    std::vector< Type > children;
+    for(TypeNode::const_iterator i = typn.begin(), iend = typn.end();i != iend; ++i) {
+      origChildren.push_back( (*i).toType() );
+      children.push_back( doParametricSubstitution( (*i).toType(), paramTypes, paramReplacements ) );
+    }
+    for( int i=0; i<(int)paramTypes.size(); i++ ){
+      if( paramTypes[i].getArity()==origChildren.size() ){
+        Type tn = paramTypes[i].instantiate( origChildren );
+        if( range==tn ){
+          return paramReplacements[i].instantiate( children );
+        }
+      }
+    }
+    NodeBuilder<> nb(typn.getKind());
+    for( int i=0; i<(int)children.size(); i++ ){
+      nb << TypeNode::fromType( children[i] );
+    }
+    return nb.constructTypeNode().toType();
+  }
 }
 
 Datatype::Constructor::Constructor(std::string name, std::string tester) :
@@ -615,6 +667,10 @@ Expr Datatype::Constructor::Arg::getConstructor() const {
   return d_constructor;
 }
 
+Type Datatype::Constructor::Arg::getSelectorType() const{
+  return getSelector().getType();
+}
+
 bool Datatype::Constructor::Arg::isUnresolvedSelf() const throw() {
   return d_selector.isNull() && d_name.size() == d_name.find('\0') + 1;
 }
index 7d9ae6f3921eeae9982e0e345a34180482afff4d..abc9e3258c65679b5515e07813e7e271773ed17f 100644 (file)
@@ -174,6 +174,12 @@ public:
        */
       Expr getConstructor() const;
 
+      /**
+       * Get the selector for this constructor argument; this call is
+       * only permitted after resolution.
+       */
+      Type getSelectorType() const;
+
       /**
        * Get the name of the type of this constructor argument
        * (Datatype field).  Can be used for not-yet-resolved Datatypes
@@ -204,10 +210,16 @@ public:
     void resolve(ExprManager* em, DatatypeType self,
                  const std::map<std::string, DatatypeType>& resolutions,
                  const std::vector<Type>& placeholders,
-                 const std::vector<Type>& replacements)
+                 const std::vector<Type>& replacements,
+                 const std::vector< SortConstructorType >& paramTypes,
+                 const std::vector< DatatypeType >& paramReplacements)
       throw(AssertionException, DatatypeResolutionException);
     friend class Datatype;
 
+    /** */
+    Type doParametricSubstitution( Type range, 
+                                   const std::vector< SortConstructorType >& paramTypes, 
+                                   const std::vector< DatatypeType >& paramReplacements );
   public:
     /**
      * Create a new Datatype constructor with the given name for the
@@ -314,6 +326,7 @@ public:
 
 private:
   std::string d_name;
+  std::vector<Type> d_params;
   std::vector<Constructor> d_constructors;
   bool d_resolved;
   Type d_self;
@@ -330,14 +343,16 @@ private:
   void resolve(ExprManager* em,
                const std::map<std::string, DatatypeType>& resolutions,
                const std::vector<Type>& placeholders,
-               const std::vector<Type>& replacements)
+               const std::vector<Type>& replacements,
+               const std::vector< SortConstructorType >& paramTypes,
+               const std::vector< DatatypeType >& paramReplacements)
     throw(AssertionException, DatatypeResolutionException);
   friend class ExprManager;// for access to resolve()
 
 public:
 
   /** Create a new Datatype of the given name. */
-  inline explicit Datatype(std::string name);
+  inline explicit Datatype(std::string name, std::vector<Type>& params);
 
   /** Add a constructor to this Datatype. */
   void addConstructor(const Constructor& c);
@@ -347,6 +362,11 @@ public:
   /** Get the number of constructors (so far) for this Datatype. */
   inline size_t getNumConstructors() const throw();
 
+  /** Get the nubmer of parameters */
+  inline size_t getNumParameters() const throw();
+  /** Get parameter */
+  inline Type getParameter( unsigned int i ) const;
+
   /**
    * Return the cardinality of this datatype (the sum of the
    * cardinalities of its constructors).  The Datatype must be
@@ -381,6 +401,12 @@ public:
    */
   DatatypeType getDatatypeType() const throw(AssertionException);
 
+  /**
+   * Get the DatatypeType associated to this (parameterized) Datatype.  Can only be
+   * called post-resolution.
+   */
+  DatatypeType getDatatypeType(const std::vector<Type>& params) const throw(AssertionException);
+
   /**
    * Return true iff the two Datatypes are the same.
    *
@@ -466,8 +492,9 @@ inline std::string Datatype::UnresolvedType::getName() const throw() {
   return d_name;
 }
 
-inline Datatype::Datatype(std::string name) :
+inline Datatype::Datatype(std::string name, std::vector<Type>& params) :
   d_name(name),
+  d_params(params),
   d_constructors(),
   d_resolved(false),
   d_self() {
@@ -481,6 +508,14 @@ inline size_t Datatype::getNumConstructors() const throw() {
   return d_constructors.size();
 }
 
+inline size_t Datatype::getNumParameters() const throw() {
+  return d_params.size();
+}
+
+inline Type Datatype::getParameter( unsigned int i ) const {
+  return d_params[i];
+}
+
 inline bool Datatype::operator!=(const Datatype& other) const throw() {
   return !(*this == other);
 }