Fair datatype enumeration.
authorajreynol <andrew.j.reynolds@gmail.com>
Thu, 25 Sep 2014 18:39:13 +0000 (20:39 +0200)
committerajreynol <andrew.j.reynolds@gmail.com>
Thu, 25 Sep 2014 18:39:13 +0000 (20:39 +0200)
src/theory/datatypes/type_enumerator.h

index a514460c7050715192e64bfa3ff73541ffc2b4c8..dc2a83841bd2d45fbf1997c6b524432c4da69a9f 100644 (file)
@@ -28,47 +28,146 @@ namespace CVC4 {
 namespace theory {
 namespace datatypes {
 
+
 class DatatypesEnumerator : public TypeEnumeratorBase<DatatypesEnumerator> {
   /** The datatype we're enumerating */
   const Datatype& d_datatype;
-  /** The datatype constructor we're currently enumerating */
-  size_t d_ctor;
-  /** The "first" constructor to consider; it's non-recursive */
-  size_t d_zeroCtor;
-  /** Delegate enumerators for the arguments of the current constructor */
-  TypeEnumerator** d_argEnumerators;
   /** type */
   TypeNode d_type;
-
-  /** Allocate and initialize the delegate enumerators */
-  void newEnumerators() {
-    d_argEnumerators = new TypeEnumerator*[d_datatype[d_ctor].getNumArgs()];
-    for(size_t a = 0; a < d_datatype[d_ctor].getNumArgs(); ++a) {
-      d_argEnumerators[a] = NULL;
+  /** The datatype constructor we're currently enumerating */
+  unsigned d_ctor;
+  /** The "first" constructor to consider; it's non-recursive */
+  unsigned d_zeroCtor;
+  /** list of type enumerators (one for each type in a selector argument) */
+  std::map< TypeNode, unsigned > d_te_index;
+  std::vector< TypeEnumerator > d_children;
+  /** terms produced for types */
+  std::map< TypeNode, std::vector< Node > > d_terms;
+  /** arg type of each selector, for each constructor */
+  std::vector< std::vector< TypeNode > > d_sel_types;
+  /** current index for each argument, for each constructor */
+  std::vector< std::vector< unsigned > > d_sel_index;
+  /** current sum of argument indicies for each constructor */
+  std::vector< int > d_sel_sum;
+  /** current bound on the number of times we can iterate argument enumerators */
+  unsigned d_size_limit;
+
+  Node getTermEnum( TypeNode tn, unsigned i ){
+    if( i<d_terms[tn].size() ){
+      return d_terms[tn][i];
+    }else{
+      Debug("dt-enum-debug") << "get term enum " << tn << " " << i << std::endl;
+      std::map< TypeNode, unsigned >::iterator it = d_te_index.find( tn );
+      unsigned tei;
+      if( it==d_te_index.end() ){
+        //initialize child enumerator for type
+        tei = d_children.size();
+        d_te_index[tn] = tei;
+        d_children.push_back( TypeEnumerator( tn ) );
+        d_terms[tn].push_back( *d_children[tei] );
+      }else{
+        tei = it->second;
+      }
+      //enumerate terms until index is reached
+      while( i>=d_terms[tn].size() ){
+        ++d_children[tei];
+        if( d_children[tei].isFinished() ){
+          Debug("dt-enum-debug") << "...fail term enum " << tn << " " << i << std::endl;
+          return Node::null();
+        }
+        d_terms[tn].push_back( *d_children[tei] );
+      }
+      Debug("dt-enum-debug") << "...return term enum " << tn << " " << i << " : " << d_terms[tn][i] << std::endl;
+      return d_terms[tn][i];
     }
   }
 
-  /** Delete the delegate enumerators */
-  void deleteEnumerators() {
-    if(d_argEnumerators != NULL) {
-      for(size_t a = 0; a < d_datatype[d_ctor].getNumArgs(); ++a) {
-        delete d_argEnumerators[a];
+  bool increment( unsigned index ){
+    Debug("dt-enum") << "Incrementing " << d_type << " " << d_ctor << " at size " << d_sel_sum[index] << "/" << d_size_limit << std::endl;
+    if( d_sel_sum[index]==-1 ){
+      //first time
+      d_sel_sum[index] = 0;
+      //special case: no children to iterate
+      if( d_sel_types[index].size()==0 ){
+        Debug("dt-enum") << "...success (nc) = " << (d_size_limit==0) << std::endl;
+        return d_size_limit==0;
+      }else{
+        Debug("dt-enum") << "...success" << std::endl;
+        return true;
       }
-      delete [] d_argEnumerators;
-      d_argEnumerators = NULL;
+    }else{
+      unsigned i = 0;
+      while( i < d_sel_index[index].size() ){
+        //increment if the sum of iterations on arguments is less than the limit
+        if( d_sel_sum[index]<(int)d_size_limit ){
+          //also check if child enumerator has enough terms
+          if( !getTermEnum( d_sel_types[index][i], d_sel_index[index][i]+1 ).isNull() ){
+            Debug("dt-enum") << "...success increment child " << i << std::endl;
+            d_sel_index[index][i]++;
+            d_sel_sum[index]++;
+            return true;
+          }
+        }
+        Debug("dt-enum") << "......failed increment child " << i << std::endl;
+        //reset child, iterate next
+        d_sel_sum[index] -= d_sel_index[index][i];
+        d_sel_index[index][i] = 0;
+        i++;
+      }
+      Debug("dt-enum") << "...failure." << std::endl;
+      return false;
     }
   }
 
+  Node getCurrentTerm( unsigned index ){
+    Debug("dt-enum-debug") << "Get current term at " << index << " " << d_type << "..." << std::endl;
+    DatatypeConstructor ctor = d_datatype[index];
+    Debug("dt-enum-debug") << "Check last term..." << std::endl;
+    //we first check if the last argument (which is forced to make sum of iterated arguments equal to d_size_limit) is defined
+    Node lc;
+    if( ctor.getNumArgs()>0 ){
+      lc = getTermEnum( d_sel_types[index][ctor.getNumArgs()-1], d_size_limit - d_sel_sum[index] );
+      if( lc.isNull() ){
+        Debug("dt-enum-debug") << "Current infeasible." << std::endl;
+        return Node::null();
+      }
+    }
+    Debug("dt-enum-debug") << "Get constructor..." << std::endl;
+    NodeBuilder<> b(kind::APPLY_CONSTRUCTOR);
+    Type typ;
+    if( d_datatype.isParametric() ){
+      typ = ctor.getSpecializedConstructorType(d_type.toType());
+      b << NodeManager::currentNM()->mkNode(kind::APPLY_TYPE_ASCRIPTION,
+                                            NodeManager::currentNM()->mkConst(AscriptionType(typ)), Node::fromExpr( ctor.getConstructor() ) );
+    }else{
+      b << ctor.getConstructor();
+    }
+    Debug("dt-enum-debug") << "Get arguments..." << std::endl;
+    if( ctor.getNumArgs()>0 ){
+      Assert( index<d_sel_types.size() );
+      Assert( index<d_sel_index.size() );
+      Assert( d_sel_types[index].size()==ctor.getNumArgs() );
+      Assert( d_sel_index[index].size()==ctor.getNumArgs()-1 );
+      for( int i=0; i<(int)(ctor.getNumArgs()-1); i++ ){
+        Node c = getTermEnum( d_sel_types[index][i], d_sel_index[index][i] );
+        Assert( !c.isNull() );
+        b << c;
+      }
+      b << lc;
+    }
+    Node nnn = Node(b);
+    Debug("dt-enum-debug") << "Return... " <<  nnn  << std::endl;
+    return nnn;
+  }
 
 public:
 
   DatatypesEnumerator(TypeNode type) throw() :
     TypeEnumeratorBase<DatatypesEnumerator>(type),
     d_datatype(DatatypeType(type.toType()).getDatatype()),
+    d_type(type),
     d_ctor(0),
-    d_zeroCtor(0),
-    d_argEnumerators(NULL),
-    d_type(type) {
+    d_zeroCtor(0) {
 
     //Assert(type.isDatatype());
     Debug("te") << "datatype is datatype? " << type.isDatatype() << std::endl;
@@ -79,96 +178,104 @@ public:
     Node t = type.mkGroundTerm();
     Assert( t.getKind()==kind::APPLY_CONSTRUCTOR );
     d_zeroCtor = Datatype::indexOf( t.getOperator().toExpr() );
-
-    /* start with the non-recursive constructor */
+    /* start with the constructor for which a ground term is constructed */
     d_ctor = d_zeroCtor;
 
-    /* allocate space for the enumerators */
-    newEnumerators();
+    for( unsigned i=0; i<d_datatype.getNumConstructors(); ++i ){
+      d_sel_types.push_back( std::vector< TypeNode >() );
+      d_sel_index.push_back( std::vector< unsigned >() );
+      d_sel_sum.push_back( -1 );
+      DatatypeConstructor ctor = d_datatype[i];
+      Type typ;
+      if( d_datatype.isParametric() ){
+        typ = ctor.getSpecializedConstructorType(d_type.toType());
+      }
+      for( unsigned a = 0; a < ctor.getNumArgs(); ++a ){
+        TypeNode tn;
+        if( d_datatype.isParametric() ){
+          tn = TypeNode::fromType( typ )[a];
+        }else{
+          tn = Node::fromExpr(ctor[a].getSelector()).getType()[1];
+        }
+        d_sel_types[i].push_back( tn );
+        d_sel_index[i].push_back( 0 );
+      }
+      if( !d_sel_index[i].empty() ){
+        d_sel_index[i].pop_back();
+      }
+    }
+    d_size_limit = 0;
+    //set up initial conditions (should always succeed)
+    bool init_inc = increment( d_ctor );
+    AlwaysAssert( init_inc );
   }
 
   DatatypesEnumerator(const DatatypesEnumerator& de) throw() :
     TypeEnumeratorBase<DatatypesEnumerator>(de.getType()),
     d_datatype(de.d_datatype),
+    d_type(de.d_type),
     d_ctor(de.d_ctor),
-    d_zeroCtor(de.d_zeroCtor),
-    d_argEnumerators(NULL),
-    d_type(de.d_type) {
+    d_zeroCtor(de.d_zeroCtor) {
 
-    if(de.d_argEnumerators != NULL) {
-      newEnumerators();
-      for(size_t a = 0; a < d_datatype[d_ctor].getNumArgs(); ++a) {
-        if(de.d_argEnumerators[a] != NULL) {
-          d_argEnumerators[a] = new TypeEnumerator(*de.d_argEnumerators[a]);
-        }
-      }
+    for( std::map< TypeNode, unsigned >::const_iterator it = de.d_te_index.begin(); it != de.d_te_index.end(); ++it ){
+      d_te_index[it->first] = it->second;
+    }
+    for( std::map< TypeNode, std::vector< Node > >::const_iterator it = de.d_terms.begin(); it != de.d_terms.end(); ++it ){
+      d_terms[it->first].insert( d_terms[it->first].end(), it->second.begin(), it->second.end() );
+    }
+    for( unsigned i=0; i<de.d_sel_types.size(); i++ ){
+      d_sel_types.push_back( std::vector< TypeNode >() );
+      d_sel_types[i].insert( d_sel_types[i].end(), de.d_sel_types[i].begin(), de.d_sel_types[i].end() );
     }
+    for( unsigned i=0; i<de.d_sel_index.size(); i++ ){
+      d_sel_index.push_back( std::vector< unsigned >() );
+      d_sel_index[i].insert( d_sel_index[i].end(), de.d_sel_index[i].begin(), de.d_sel_index[i].end() );
+    }
+
+    d_children.insert( d_children.end(), de.d_children.begin(), de.d_children.end() );
+    d_sel_sum.insert( d_sel_sum.end(), de.d_sel_sum.begin(), de.d_sel_sum.end() );
+    d_size_limit = de.d_size_limit;
   }
 
   ~DatatypesEnumerator() throw() {
-    deleteEnumerators();
   }
 
   Node operator*() throw(NoMoreValuesException) {
+    Debug("dt-enum-debug") << ": get term " << this << std::endl;
     if(d_ctor < d_datatype.getNumConstructors()) {
-      DatatypeConstructor ctor = d_datatype[d_ctor];
-      NodeBuilder<> b(kind::APPLY_CONSTRUCTOR);
-      Type typ;
-      if( d_datatype.isParametric() ){
-        typ = d_datatype[d_ctor].getSpecializedConstructorType(d_type.toType());
-        b << NodeManager::currentNM()->mkNode(kind::APPLY_TYPE_ASCRIPTION,
-                                              NodeManager::currentNM()->mkConst(AscriptionType(typ)), Node::fromExpr( ctor.getConstructor() ) );
-      }else{
-        b << ctor.getConstructor();
-      }
-      try {
-        for(size_t a = 0; a < d_datatype[d_ctor].getNumArgs(); ++a) {
-          if(d_argEnumerators[a] == NULL) {
-            if( d_datatype.isParametric() ){
-              d_argEnumerators[a] = new TypeEnumerator(TypeNode::fromType( typ )[a]);
-            }else{
-              d_argEnumerators[a] = new TypeEnumerator(Node::fromExpr(d_datatype[d_ctor][a].getSelector()).getType()[1]);
-            }
-          }
-          b << **d_argEnumerators[a];
-        }
-      } catch(NoMoreValuesException&) {
-        InternalError();
-      }
-      Node nnn = Node(b);
-      //if( nnn.getType()!=d_type || !nnn.getType().isComparableTo(d_type) ){
-      //  Debug("dt-warn") << "WARNING : Enum : " << nnn << " bad type : " << nnn.getType() << " " << d_type << std::endl;
-      //}
-      return nnn;
+      return getCurrentTerm( d_ctor );
     } else {
       throw NoMoreValuesException(getType());
     }
   }
 
   DatatypesEnumerator& operator++() throw() {
-    if(d_ctor < d_datatype.getNumConstructors()) {
-      for(size_t a = d_datatype[d_ctor].getNumArgs(); a > 0; --a) {
-        if((++*d_argEnumerators[a - 1]).isFinished()) {
-          *d_argEnumerators[a - 1] = TypeEnumerator(Node::fromExpr(d_datatype[d_ctor][a - 1].getSelector()).getType()[1]);
-        } else {
+    Debug("dt-enum-debug") << ": increment " << this << std::endl;
+    unsigned prevSize = d_size_limit;
+    while(d_ctor < d_datatype.getNumConstructors()) {
+      //increment at index
+      while( increment( d_ctor ) ){
+        Node n = getCurrentTerm( d_ctor );
+        if( !n.isNull() ){
           return *this;
         }
       }
-
       // Here, we need to step from the current constructor to the next one
 
-      // first, delete the current delegate enumerators
-      deleteEnumerators();
-
       // Find the next constructor (only complicated by the notion of the "zero" constructor
       d_ctor = (d_ctor == d_zeroCtor) ? 0 : d_ctor + 1;
       if(d_ctor == d_zeroCtor) {
         ++d_ctor;
       }
-
-      // If we aren't out of constructors, allocate space for the new delegate enumerators
-      if(d_ctor < d_datatype.getNumConstructors()) {
-        newEnumerators();
+      if( d_ctor>=d_datatype.getNumConstructors() ){
+        //try next size limit as long as new terms were generated at last size
+        if( prevSize==d_size_limit ){
+          d_size_limit++;
+          d_ctor = d_zeroCtor;
+          for( unsigned i=0; i<d_sel_sum.size(); i++ ){
+            d_sel_sum[i] = -1;
+          }
+        }
       }
     }
     return *this;