Preparation for models for co-inductive datatypes. Minor cleanup.
[cvc5.git] / src / theory / datatypes / theory_datatypes.cpp
index 9316c3fe87e8a63a27657f6e1c19c1104f89ba3c..2715f8e759e5e9eb805c35266da7cd6d93f92b8e 100644 (file)
@@ -675,10 +675,10 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
                 d_infer.push_back( eq );
                 d_infer_exp.push_back( unifEq );
               }
-            }    
+            }
 /*
             std::vector< Node > rew;
-            if( DatatypesRewriter::checkClash( cons1, cons2, rew ) ){  
+            if( DatatypesRewriter::checkClash( cons1, cons2, rew ) ){
               Assert(false);
             }else{
               for( unsigned i=0; i<rew.size(); i++ ){
@@ -1072,7 +1072,7 @@ void TheoryDatatypes::computeCareGraph(){
                 }
               }
             }
-          } 
+          }
           if (!somePairIsDisequal) {
             for (unsigned c = 0; c < currentPairs.size(); ++ c) {
               Trace("dt-cg-pair") << "Pair : " << currentPairs[c].first << " " << currentPairs[c].second << std::endl;
@@ -1133,6 +1133,7 @@ void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
   eq::EqClassesIterator eqccs_i = eq::EqClassesIterator( &d_equalityEngine );
   std::vector< Node > cons;
   std::vector< Node > nodes;
+  std::map< Node, Node > eqc_cons;
   while( !eqccs_i.isFinished() ){
     Node eqc = (*eqccs_i);
     //for all equivalence classes that are datatypes
@@ -1140,51 +1141,55 @@ void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
       EqcInfo* ei = getOrMakeEqcInfo( eqc );
       if( !ei->d_constructor.get().isNull() ){
         cons.push_back( ei->d_constructor.get() );
+        eqc_cons[ eqc ] = ei->d_constructor.get();
       }else{
         nodes.push_back( eqc );
       }
     }
     ++eqccs_i;
   }
+
   unsigned orig_size = nodes.size();
   unsigned index = 0;
   while( index<nodes.size() ){
     Node eqc = nodes[index];
     Node neqc;
+    const Datatype& dt = ((DatatypeType)(eqc.getType()).toType()).getDatatype();
     if( !d_equalityEngine.hasTerm( eqc ) ){
-      Trace("dt-cmi") << "NOTICE : Datatypes: need to assign constructor for " << eqc << std::endl;
-      Trace("dt-cmi") << "   Type : " << eqc.getType() << std::endl;
-      //can assign arbitrary term
-      TypeEnumerator te(eqc.getType());
-      bool success;
-      do{
-        success = true;
-        Assert( !te.isFinished() );
-        neqc = *te;
-        Trace("dt-cmi-debug") << "Try " << neqc << " ... " << std::endl;
-        ++te;
-        //if it is infinite or we are assigning to terms known by datatypes, make sure it is fresh
-        if( eqc.getType().getCardinality().isInfinite() || index<orig_size ){
-          for( unsigned i=0; i<cons.size(); i++ ){
-            //check if it is modulo equality the same
-            if( cons[i].getOperator()==neqc.getOperator() ){
-              bool diff = false;
-              for( unsigned j=0; j<cons[i].getNumChildren(); j++ ){
-                if( !m->areEqual( cons[i][j], neqc[j] ) ){
-                  diff = true;
+      if( !dt.isCodatatype() ){
+        Trace("dt-cmi") << "NOTICE : Datatypes: need to assign constructor for " << eqc << std::endl;
+        Trace("dt-cmi") << "   Type : " << eqc.getType() << std::endl;
+        //can assign arbitrary term
+        TypeEnumerator te(eqc.getType());
+        bool success;
+        do{
+          success = true;
+          Assert( !te.isFinished() );
+          neqc = *te;
+          Trace("dt-cmi-debug") << "Try " << neqc << " ... " << std::endl;
+          ++te;
+          //if it is infinite or we are assigning to terms known by datatypes, make sure it is fresh
+          if( eqc.getType().getCardinality().isInfinite() || index<orig_size ){
+            for( unsigned i=0; i<cons.size(); i++ ){
+              //check if it is modulo equality the same
+              if( cons[i].getOperator()==neqc.getOperator() ){
+                bool diff = false;
+                for( unsigned j=0; j<cons[i].getNumChildren(); j++ ){
+                  if( !m->areEqual( cons[i][j], neqc[j] ) ){
+                    diff = true;
+                    break;
+                  }
+                }
+                if( !diff ){
+                  Trace("dt-cmi-debug") << "...Already equivalent modulo equality to " << cons[i] << std::endl;
+                  success = false;
                   break;
                 }
               }
-              if( !diff ){
-                Trace("dt-cmi-debug") << "...Already equivalent modulo equality to " << cons[i] << std::endl;
-                success = false;
-                break;
-              }
             }
           }
-        }
-      }while( !success );
+        }while( !success );
+      }
     }else{
       Trace("dt-cmi") << "NOTICE : Datatypes: no constructor in equivalence class " << eqc << std::endl;
       Trace("dt-cmi") << "   Type : " << eqc.getType() << std::endl;
@@ -1212,8 +1217,6 @@ void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
       }
       Trace("dt-cmi") << std::endl;
       */
-
-      const Datatype& dt = ((DatatypeType)(eqc.getType()).toType()).getDatatype();
       for( unsigned r=0; r<2; r++ ){
         if( neqc.isNull() ){
           for( unsigned i=0; i<pcons.size(); i++ ){
@@ -1232,9 +1235,11 @@ void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
         }
       }
     }
-    Assert( !neqc.isNull() );
-    Trace("dt-cmi") << "Assign : " << neqc << std::endl;
-    m->assertEquality( eqc, neqc, true );
+    if( !neqc.isNull() ){
+      Trace("dt-cmi") << "Assign : " << neqc << std::endl;
+      m->assertEquality( eqc, neqc, true );
+      eqc_cons[ eqc ] = neqc;
+    }
     /*
     for( unsigned kk=0; kk<all_eqc.size(); kk++ ){
       for( unsigned ll=(kk+1); ll<all_eqc.size(); ll++ ){
@@ -1258,9 +1263,44 @@ void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
     ++index;
   }
 
+  //assign MU for each codatatype eqc
+  std::map< Node, Node > eqc_mu;
+  for( std::map< Node, Node >::iterator it = eqc_cons.begin(); it != eqc_cons.end(); ++it ){
+    Node eqc = it->first;
+    const Datatype& dt = ((DatatypeType)(eqc.getType()).toType()).getDatatype();
+    if( dt.isCodatatype() ){
+      std::map< Node, Node > vmap;
+      Node v = getCodatatypesValue( it->first, eqc_cons, eqc_mu, vmap );
+      Trace("dt-cmi-cod") << "  EQC(" << it->first << "), constructor is " << it->second << ", value is " << v << std::endl;
+    }
+  }
 }
 
 
+Node TheoryDatatypes::getCodatatypesValue( Node n, std::map< Node, Node >& eqc_cons, std::map< Node, Node >& eqc_mu, std::map< Node, Node >& vmap ){
+  std::map< Node, Node >::iterator itv = vmap.find( n );
+  if( itv!=vmap.end() ){
+    return itv->second;
+  }else if( DatatypesRewriter::isTermDatatype( n ) ){
+    Node nv = NodeManager::currentNM()->mkBoundVar( n.getType() );
+    vmap[n] = nv;
+    Trace("dt-cmi-cod-debug") << "    map " << n << " -> " << nv << std::endl;
+    Node nc = eqc_cons[n];
+    Assert( nc.getKind()==APPLY_CONSTRUCTOR );
+    std::vector< Node > children;
+    children.push_back( nc.getOperator() );
+    for( unsigned i=0; i<nc.getNumChildren(); i++ ){
+      Node r = getRepresentative( nc[i] );
+      Node rv = getCodatatypesValue( r, eqc_cons, eqc_mu, vmap );
+      children.push_back( rv );
+    }
+    vmap.erase( n );
+    return NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, children );
+  }else{
+    return n;
+  }
+}
+
 void TheoryDatatypes::collectTerms( Node n ) {
   if( d_collectTermsCache.find( n )==d_collectTermsCache.end() ){
     d_collectTermsCache[n] = true;
@@ -1802,7 +1842,7 @@ bool TheoryDatatypes::checkClashModEq( Node n1, Node n2, std::vector< Node >& ex
     }
   }else if( n1!=n2 ){
     if( n1.isConst() && n2.isConst() ){
-      return true;        
+      return true;
     }else{
       Node eq = NodeManager::currentNM()->mkNode( n1.getType().isBoolean() ? kind::IFF : kind::EQUAL, n1, n2 );
       if( d_equalityEngine.areDisequal( n1, n2, true ) ){