Add initial support for co-datatypes.
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 14 Apr 2014 09:28:44 +0000 (04:28 -0500)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 14 Apr 2014 09:29:25 +0000 (04:29 -0500)
src/compat/cvc3_compat.cpp
src/parser/cvc/Cvc.g
src/parser/smt2/Smt2.g
src/smt/boolean_terms.cpp
src/theory/builtin/kinds
src/theory/builtin/theory_builtin_type_rules.h
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h
src/util/datatype.h

index 27d1249a179b7a914aeab254dafaaeb5426076b3..427282490979406af15a2fa295aee64f76afb08a 100644 (file)
@@ -1189,7 +1189,7 @@ void ValidityChecker::dataType(const std::vector<std::string>& names,
 
   // Set up the datatype specifications.
   for(unsigned i = 0; i < names.size(); ++i) {
-    CVC4::Datatype dt(names[i]);
+    CVC4::Datatype dt(names[i], false);
     CVC4::CheckArgument(constructors[i].size() == selectors[i].size(), "expected sub-vectors in constructors and selectors vectors to match in size");
     CVC4::CheckArgument(constructors[i].size() == types[i].size(), "expected sub-vectors in constructors and types vectors to match in size");
     for(unsigned j = 0; j < constructors[i].size(); ++j) {
index 792c3cf9d7f7a0c8c47d7ab04b6350a6b44b32b4..f9055f5db9acda6d53f6405c566dbaf85674505d 100644 (file)
@@ -2018,7 +2018,7 @@ datatypeDef[std::vector<CVC4::Datatype>& datatypes]
         params.push_back( t ); }
       )* RBRACKET
     )?
-    { datatypes.push_back(Datatype(id, params));
+    { datatypes.push_back(Datatype(id, params, false));
       if(!PARSER_STATE->isUnresolvedType(id)) {
         // if not unresolved, must be undeclared
         PARSER_STATE->checkDeclaration(id, CHECK_UNDECLARED, SYM_SORT);
index 659fc97d9d4c61289fb1da5cd268e248527cfeda..2118a240defcb9cc4c5502674ea7af0a1fa0d5ef 100644 (file)
@@ -464,17 +464,8 @@ extendedCommand[CVC4::Command*& cmd]
 }
     /* Extended SMT-LIB set of commands syntax, not permitted in
      * --smtlib2 compliance mode. */
-  : DECLARE_DATATYPES_TOK { PARSER_STATE->checkThatLogicIsSet(); }
-    { /* open a scope to keep the UnresolvedTypes contained */
-      PARSER_STATE->pushScope(true); }
-    LPAREN_TOK /* parametric sorts */
-      ( symbol[name,CHECK_UNDECLARED,SYM_SORT] {
-        sorts.push_back( PARSER_STATE->mkSort(name) ); }
-      )*
-    RPAREN_TOK
-    LPAREN_TOK ( LPAREN_TOK datatypeDef[dts, sorts] RPAREN_TOK )+ RPAREN_TOK
-    { PARSER_STATE->popScope();
-      cmd = new DatatypeDeclarationCommand(PARSER_STATE->mkMutualDatatypeTypes(dts)); }
+  : DECLARE_DATATYPES_TOK datatypesDefCommand[false, cmd]
+  | DECLARE_CODATATYPES_TOK datatypesDefCommand[true, cmd]
   | /* get model */
     GET_MODEL_TOK { PARSER_STATE->checkThatLogicIsSet(); }
     { cmd = new GetModelCommand(); }
@@ -603,6 +594,26 @@ extendedCommand[CVC4::Command*& cmd]
     { cmd = new SimplifyCommand(e); }
   ;
 
+
+datatypesDefCommand[bool isCo, CVC4::Command*& cmd]
+@declarations {
+  std::vector<CVC4::Datatype> dts;
+  std::string name;
+  std::vector<Type> sorts;
+}
+  : { PARSER_STATE->checkThatLogicIsSet(); 
+    /* open a scope to keep the UnresolvedTypes contained */
+    PARSER_STATE->pushScope(true); }
+  LPAREN_TOK /* parametric sorts */
+  ( symbol[name,CHECK_UNDECLARED,SYM_SORT] {
+    sorts.push_back( PARSER_STATE->mkSort(name) ); }
+  )*
+  RPAREN_TOK
+  LPAREN_TOK ( LPAREN_TOK datatypeDef[isCo, dts, sorts] RPAREN_TOK )+ RPAREN_TOK
+  { PARSER_STATE->popScope();
+  cmd = new DatatypeDeclarationCommand(PARSER_STATE->mkMutualDatatypeTypes(dts)); }
+  ;
+
 rewriterulesCommand[CVC4::Command*& cmd]
 @declarations {
   std::vector<std::pair<std::string, Type> > sortedVarNames;
@@ -1530,7 +1541,7 @@ nonemptyNumeralList[std::vector<uint64_t>& numerals]
 /**
  * Parses a datatype definition
  */
-datatypeDef[std::vector<CVC4::Datatype>& datatypes, std::vector< CVC4::Type >& params]
+datatypeDef[bool isCo, std::vector<CVC4::Datatype>& datatypes, std::vector< CVC4::Type >& params]
 @init {
   std::string id;
 }
@@ -1548,7 +1559,7 @@ datatypeDef[std::vector<CVC4::Datatype>& datatypes, std::vector< CVC4::Type >& p
         params.push_back( t ); }
       )* ']'
     )?*/ //AJR: this isn't necessary if we use z3's style
-    { datatypes.push_back(Datatype(id,params));
+    { datatypes.push_back(Datatype(id,params,isCo));
       if(!PARSER_STATE->isUnresolvedType(id)) {
         // if not unresolved, must be undeclared
         PARSER_STATE->checkDeclaration(id, CHECK_UNDECLARED, SYM_SORT);
@@ -1623,6 +1634,7 @@ AS_TOK : 'as';
 
 // extended commands
 DECLARE_DATATYPES_TOK : 'declare-datatypes';
+DECLARE_CODATATYPES_TOK : 'declare-codatatypes';
 GET_MODEL_TOK : 'get-model';
 ECHO_TOK : 'echo';
 REWRITE_RULE_TOK : 'assert-rewrite';
index e46a76ed7671dcd8499b36dcd23d9a923d874fed..c779af4ffc99c8ffbe68e6026d3fe38746b62bbe 100644 (file)
@@ -228,9 +228,9 @@ const Datatype& BooleanTermConverter::convertDatatype(const Datatype& dt) throw(
         }
         vector<Datatype> newDtVector;
         if(dt.isParametric()) {
-          newDtVector.push_back(Datatype(dt.getName() + "'", dt.getParameters()));
+          newDtVector.push_back(Datatype(dt.getName() + "'", dt.getParameters(), false));
         } else {
-          newDtVector.push_back(Datatype(dt.getName() + "'"));
+          newDtVector.push_back(Datatype(dt.getName() + "'", false));
         }
         Datatype& newDt = newDtVector.front();
         Debug("boolean-terms") << "found a Boolean arg in constructor " << (*c).getName() << endl;
index b3383e6c42e1b611cc638369a8ad1ade51d88d27..d140d19908e1a63f581ae0047afdc6d684cf2b77 100644 (file)
@@ -302,6 +302,7 @@ variable SKOLEM "skolem var"
 operator SEXPR 0: "a symbolic expression"
 
 operator LAMBDA 2 "lambda"
+operator MU 2 "mu"
 
 parameterized CHAIN CHAIN_OP 2: "chained operator"
 constant CHAIN_OP \
@@ -334,6 +335,7 @@ typerule EQUAL ::CVC4::theory::builtin::EqualityTypeRule
 typerule DISTINCT ::CVC4::theory::builtin::DistinctTypeRule
 typerule SEXPR ::CVC4::theory::builtin::SExprTypeRule
 typerule LAMBDA ::CVC4::theory::builtin::LambdaTypeRule
+typerule MU ::CVC4::theory::builtin::MuTypeRule
 typerule CHAIN ::CVC4::theory::builtin::ChainTypeRule
 typerule CHAIN_OP ::CVC4::theory::builtin::ChainedOperatorTypeRule
 
index c7143bdeb7d40c2e03d608c27fd8c7123bd399c3..f35286f05268d6a01fd10897e27df7def59ebcd0 100644 (file)
@@ -164,6 +164,24 @@ public:
   }
 };/* class LambdaTypeRule */
 
+class MuTypeRule {
+public:
+  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) {
+    if( n[0].getType(check) != nodeManager->boundVarListType() ) {
+      std::stringstream ss;
+      ss << "expected a bound var list for MU expression, got `"
+         << n[0].getType().toString() << "'";
+      throw TypeCheckingExceptionPrivate(n, ss.str());
+    }
+    std::vector<TypeNode> argTypes;
+    for(TNode::iterator i = n[0].begin(); i != n[0].end(); ++i) {
+      argTypes.push_back((*i).getType());
+    }
+    TypeNode rangeType = n[1].getType(check);
+    return nodeManager->mkFunctionType(argTypes, rangeType);
+  }
+};/* class MuTypeRule */
+
 class ChainTypeRule {
 public:
   inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) {
index 0b0f5807c8698fd1fa0756e4bf19933b003e062c..fc37c5417f4a3883a2611869633e6d5464a73186 100644 (file)
@@ -123,13 +123,19 @@ void TheoryDatatypes::check(Effort e) {
 
   if( e == EFFORT_FULL ) {
     //check for cycles
-    checkCycles();
-    if( d_conflict ){
-      return;
-    }
+    bool addedFact;
+    do {
+      checkCycles();
+      addedFact = !d_pending.empty() || !d_pending_merge.empty();
+      flushPendingFacts();
+      if( d_conflict ){
+        return;
+      }
+    }while( addedFact );
+
     //check for splits
     Debug("datatypes-split") << "Check for splits " << e << endl;
-    bool addedFact = false;
+    addedFact = false;
     do {
       eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine );
       while( !eqcs_i.isFinished() ){
@@ -146,7 +152,7 @@ void TheoryDatatypes::check(Effort e) {
               Node t = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[0].getTester() ), n );
               d_pending.push_back( t );
               d_pending_exp[ t ] = NodeManager::currentNM()->mkConst( true );
-              Trace("datatypes-infer") << "DtInfer : " << t << ", trivial" << std::endl;
+              Trace("datatypes-infer") << "DtInfer : 1-cons : " << t << std::endl;
               d_infer.push_back( t );
             }else{
               std::vector< bool > pcons;
@@ -493,21 +499,37 @@ bool TheoryDatatypes::propagate(TNode literal){
   return ok;
 }
 
+void TheoryDatatypes::addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions ) {
+  for( unsigned i=0; i<tassumptions.size(); i++ ){
+    if( std::find( assumptions.begin(), assumptions.end(), tassumptions[i] )==assumptions.end() ){
+      assumptions.push_back( tassumptions[i] );
+    }
+  }
+}
+
+void TheoryDatatypes::explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions ) {
+  if( a!=b ){
+    std::vector<TNode> tassumptions;
+    d_equalityEngine.explainEquality(a, b, polarity, tassumptions);
+    addAssumptions( assumptions, tassumptions );
+  }
+}
+
+void TheoryDatatypes::explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions ) {
+  std::vector<TNode> tassumptions;
+  d_equalityEngine.explainPredicate(p, polarity, tassumptions);
+  addAssumptions( assumptions, tassumptions );
+}
+
 /** explain */
 void TheoryDatatypes::explain(TNode literal, std::vector<TNode>& assumptions){
   Debug("datatypes-explain") << "Explain " << literal << std::endl;
   bool polarity = literal.getKind() != kind::NOT;
   TNode atom = polarity ? literal : literal[0];
-  std::vector<TNode> tassumptions;
   if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) {
-    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, tassumptions);
+    explainEquality( atom[0], atom[1], polarity, assumptions );
   } else {
-    d_equalityEngine.explainPredicate(atom, polarity, tassumptions);
-  }
-  for( unsigned i=0; i<tassumptions.size(); i++ ){
-    if( std::find( assumptions.begin(), assumptions.end(), tassumptions[i] )==assumptions.end() ){
-      assumptions.push_back( tassumptions[i] );
-    }
+    explainPredicate( atom, polarity, assumptions );
   }
 }
 
@@ -585,7 +607,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
                 Node eq = cons1[i].getType().isBoolean() ? cons1[i].iffNode( cons2[i] ) : cons1[i].eqNode( cons2[i] );
                 d_pending.push_back( eq );
                 d_pending_exp[ eq ] = unifEq;
-                Trace("datatypes-infer") << "DtInfer : " << eq << " by " << unifEq << std::endl;
+                Trace("datatypes-infer") << "DtInfer : cons-inj : " << eq << " by " << unifEq << std::endl;
                 d_infer.push_back( eq );
                 d_infer_exp.push_back( unifEq );
               }
@@ -737,7 +759,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){
         //conflict because equivalence class contains a constructor
         std::vector< TNode > assumptions;
         explain( t, assumptions );
-        explain( eqc->d_constructor.get().eqNode( tt[0] ), assumptions );
+        explainEquality( eqc->d_constructor.get(), tt[0], true, assumptions );
         d_conflictNode = mkAnd( assumptions );
         Trace("dt-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl;
         d_out->conflict( d_conflictNode );
@@ -808,7 +830,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){
           Node t_concl_exp = ( nb.getNumChildren() == 1 ) ? nb.getChild( 0 ) : nb;
           d_pending.push_back( t_concl );
           d_pending_exp[ t_concl ] = t_concl_exp;
-          Trace("datatypes-infer") << "DtInfer : " << t_concl << " by " << t_concl_exp << std::endl;
+          Trace("datatypes-infer") << "DtInfer : label : " << t_concl << " by " << t_concl_exp << std::endl;
           d_infer.push_back( t_concl );
           d_infer_exp.push_back( t_concl_exp );
           return;
@@ -822,7 +844,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){
     std::vector< TNode > assumptions;
     explain( j, assumptions );
     explain( t, assumptions );
-    explain( jt[0].eqNode( tt[0] ), assumptions );
+    explainEquality( jt[0], tt[0], true, assumptions );
     d_conflictNode = mkAnd( assumptions );
     Trace("dt-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl;
     d_out->conflict( d_conflictNode );
@@ -866,7 +888,7 @@ void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){
           Node n = *i;
           std::vector< TNode > assumptions;
           explain( *i, assumptions );
-          explain( c.eqNode( (*i)[0][0] ), assumptions );
+          explainEquality( c, (*i)[0][0], true, assumptions );
           d_conflictNode = mkAnd( assumptions );
           Trace("dt-conflict") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl;
           d_out->conflict( d_conflictNode );
@@ -901,7 +923,7 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) {
 
     d_pending.push_back( eq );
     d_pending_exp[ eq ] = eq_exp;
-    Trace("datatypes-infer") << "DtInfer : " << eq << " by " << eq_exp << " (collapse selector)" << std::endl;
+    Trace("datatypes-infer") << "DtInfer : collapse sel : " << eq << " by " << eq_exp << std::endl;
     d_infer.push_back( eq );
     d_infer_exp.push_back( eq_exp );
   }
@@ -1141,7 +1163,7 @@ void TheoryDatatypes::processNewTerm( Node n ){
     Node eq = n.getType().isBoolean() ? rn.iffNode( n ) : rn.eqNode( n );
     d_pending.push_back( eq );
     d_pending_exp[ eq ] = NodeManager::currentNM()->mkConst( true );
-    Trace("datatypes-infer") << "DtInfer : " << eq << ", trivial" << std::endl;
+    Trace("datatypes-infer") << "DtInfer : rewrite : " << eq << std::endl;
     d_infer.push_back( eq );
   }
 }
@@ -1219,7 +1241,7 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
       Debug("datatypes-inst") << "DtInstantiate : " << eqc << " " << eq << std::endl;
       d_pending.push_back( eq );
       d_pending_exp[ eq ] = exp;
-      Trace("datatypes-infer") << "DtInfer : " << eq << " by " << exp << std::endl;
+      Trace("datatypes-infer") << "DtInfer : instantiate : " << eq << " by " << exp << std::endl;
       //eqc->d_inst.set( eq );
       d_infer.push_back( eq );
       d_infer_exp.push_back( exp );
@@ -1233,33 +1255,201 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
 
 void TheoryDatatypes::checkCycles() {
   Debug("datatypes-cycle-check") << "Check cycles" << std::endl;
+  std::vector< Node > cod_eqc;
   eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine );
   while( !eqcs_i.isFinished() ){
     Node eqc = (*eqcs_i);
     if( DatatypesRewriter::isTermDatatype( eqc ) ) {
-      map< Node, bool > visited;
-      std::vector< TNode > expl;
-      Node cn = searchForCycle( eqc, eqc, visited, expl );
-      //if we discovered a different cycle while searching this one
-      if( !cn.isNull() && cn!=eqc ){
-        visited.clear();
-        expl.clear();
-        Node prev = cn;
-        cn = searchForCycle( cn, cn, visited, expl );
-        Assert( prev==cn );
-      }
+      const Datatype& dt = ((DatatypeType)(eqc.getType()).toType()).getDatatype();
+      if( !dt.isCodatatype() ){
+        //do cycle checks
+        map< Node, bool > visited;
+        std::vector< TNode > expl;
+        Node cn = searchForCycle( eqc, eqc, visited, expl );
+        //if we discovered a different cycle while searching this one
+        if( !cn.isNull() && cn!=eqc ){
+          visited.clear();
+          expl.clear();
+          Node prev = cn;
+          cn = searchForCycle( cn, cn, visited, expl );
+          Assert( prev==cn );
+        }
 
-      if( !cn.isNull() ) {
-        Assert( expl.size()>0 );
-        d_conflictNode = mkAnd( expl );
-        Trace("dt-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl;
-        d_out->conflict( d_conflictNode );
-        d_conflict = true;
-        return;
+        if( !cn.isNull() ) {
+          Assert( expl.size()>0 );
+          d_conflictNode = mkAnd( expl );
+          Trace("dt-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl;
+          d_out->conflict( d_conflictNode );
+          d_conflict = true;
+          return;
+        }
+      }else{
+        //indexing
+        cod_eqc.push_back( eqc );
       }
     }
     ++eqcs_i;
   }
+  //process codatatypes
+  if( cod_eqc.size()>1 ){
+    std::vector< std::vector< Node > > part_out;
+    std::vector< TNode > exp;
+    std::map< Node, Node > cn;
+    std::map< Node, std::map< Node, int > > dni;
+    for( unsigned i=0; i<cod_eqc.size(); i++ ){
+      cn[cod_eqc[i]] = cod_eqc[i];
+    }
+    separateBisimilar( cod_eqc, part_out, exp, cn, dni, 0, false );
+    if( !part_out.empty() ){
+      for( unsigned i=0; i<part_out.size(); i++ ){
+        std::vector< Node > part;
+        part.push_back( part_out[i][0] );
+        for( unsigned j=1; j<part_out[i].size(); j++ ){
+          Trace("dt-cod") << "Codatatypes : " << part_out[i][0] << " and " << part_out[i][j] << " must be equal!!" << std::endl;
+          part.push_back( part_out[i][j] );
+          std::vector< std::vector< Node > > tpart_out;
+          exp.clear();
+          cn.clear();
+          cn[part_out[i][0]] = part_out[i][0];
+          cn[part_out[i][j]] = part_out[i][j];
+          dni.clear();
+          separateBisimilar( part, tpart_out, exp, cn, dni, 0, true );
+          Assert( tpart_out.size()==1 && tpart_out[0].size()==2 );
+          part.pop_back();
+          //merge based on explanation
+          Trace("dt-cod") << "  exp is : ";
+          for( unsigned k=0; k<exp.size(); k++ ){
+            Trace("dt-cod") << exp[k] << " ";
+          }
+          Trace("dt-cod") << std::endl;
+          Node eq = part_out[i][0].eqNode( part_out[i][j] );
+          Node eqExp = mkAnd( exp );
+          d_pending.push_back( eq );
+          d_pending_exp[ eq ] = eqExp;
+          Trace("datatypes-infer") << "DtInfer : cod-bisimilar : " << eq << " by " << eqExp << std::endl;
+          d_infer.push_back( eq );
+          d_infer_exp.push_back( eqExp );
+        }
+      }
+    }
+  }
+}
+
+//everything is in terms of representatives
+void TheoryDatatypes::separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out,
+                                         std::vector< TNode >& exp,
+                                         std::map< Node, Node >& cn,
+                                         std::map< Node, std::map< Node, int > >& dni, int dniLvl, bool mkExp ){
+  if( !mkExp ){
+    Trace("dt-cod-debug") << "Separate bisimilar : " << std::endl;
+    for( unsigned i=0; i<part.size(); i++ ){
+      Trace("dt-cod-debug") << "   " << part[i] << ", current = " << cn[part[i]] << std::endl;
+    }
+  }
+  Assert( part.size()>1 );
+  std::map< Node, std::vector< Node > > new_part;
+  std::map< Node, std::vector< Node > > new_part_c;
+  std::map< int, std::vector< Node > > new_part_rec;
+
+  std::map< Node, Node > cn_cons;
+  for( unsigned j=0; j<part.size(); j++ ){
+    Node c = cn[part[j]];
+    std::map< Node, int >::iterator it_rec = dni[part[j]].find( c );
+    if( it_rec!=dni[part[j]].end() ){
+      //looped
+      if( !mkExp ){ Trace("dt-cod-debug") << "  - " << part[j] << " is looping at index " << it_rec->second << std::endl; }
+      new_part_rec[ it_rec->second ].push_back( part[j] );
+    }else{
+      if( DatatypesRewriter::isTermDatatype( c ) ){
+        bool foundCons = false;
+        EqcInfo* eqc = getOrMakeEqcInfo( c, false );
+        if( eqc ){
+          Node ncons = eqc->d_constructor.get();
+          if( !ncons.isNull() ) {
+            foundCons = true;
+            Node cc = ncons.getOperator();
+            cn_cons[part[j]] = ncons;
+            if( mkExp ){
+              explainEquality( c, ncons, true, exp );
+            }
+            new_part[cc].push_back( part[j] );
+            if( !mkExp ){ Trace("dt-cod-debug") << "  - " << part[j] << " is datatype " << ncons << "." << std::endl; }
+          }
+        }
+        if( !foundCons ){
+          new_part_c[c].push_back( part[j] );
+          if( !mkExp ){ Trace("dt-cod-debug") << "  - " << part[j] << " is unspecified datatype." << std::endl; }
+        }
+      }else{
+        //add equivalences
+        if( !mkExp ){ Trace("dt-cod-debug") << "  - " << part[j] << " is term " << c << "." << std::endl; }
+        new_part_c[c].push_back( part[j] );
+      }
+    }
+  }
+  //direct add for constants
+  for( std::map< Node, std::vector< Node > >::iterator it = new_part_c.begin(); it != new_part_c.end(); ++it ){
+    if( it->second.size()>1 ){
+      std::vector< Node > vec;
+      vec.insert( vec.begin(), it->second.begin(), it->second.end() );
+      part_out.push_back( vec );
+    }
+  }
+  //direct add for recursive
+  for( std::map< int, std::vector< Node > >::iterator it = new_part_rec.begin(); it != new_part_rec.end(); ++it ){
+    if( it->second.size()>1 ){
+      std::vector< Node > vec;
+      vec.insert( vec.begin(), it->second.begin(), it->second.end() );
+      part_out.push_back( vec );
+    }else{
+      //add back : could match a datatype?
+    }
+  }
+  //recurse for the datatypes
+  for( std::map< Node, std::vector< Node > >::iterator it = new_part.begin(); it != new_part.end(); ++it ){
+    if( it->second.size()>1 ){
+      //set dni to check for loops
+      std::map< Node, Node > dni_rem;
+      for( unsigned i=0; i<it->second.size(); i++ ){
+        Node n = it->second[i];
+        dni[n][cn[n]] = dniLvl;
+        dni_rem[n] = cn[n];
+      }
+
+      //we will split based on the arguments of the datatype
+      std::vector< std::vector< Node > > split_new_part;
+      split_new_part.push_back( it->second );
+
+      unsigned nChildren = cn_cons[it->second[0]].getNumChildren();
+      //for each child of constructor
+      unsigned cindex = 0;
+      while( cindex<nChildren && !split_new_part.empty() ){
+        if( !mkExp ){ Trace("dt-cod-debug") << "Split argument #" << cindex << " of " << it->first << "..." << std::endl; }
+        std::vector< std::vector< Node > > next_split_new_part;
+        for( unsigned j=0; j<split_new_part.size(); j++ ){
+          //set current node
+          for( unsigned k=0; k<split_new_part[j].size(); k++ ){
+            Node n = split_new_part[j][k];
+            cn[n] = getRepresentative( cn_cons[n][cindex] );
+            if( mkExp ){
+              explainEquality( cn[n], cn_cons[n][cindex], true, exp );
+            }
+          }
+          std::vector< std::vector< Node > > c_part_out;
+          separateBisimilar( split_new_part[j], c_part_out, exp, cn, dni, dniLvl+1, mkExp );
+          next_split_new_part.insert( next_split_new_part.end(), c_part_out.begin(), c_part_out.end() );
+        }
+        split_new_part.clear();
+        split_new_part.insert( split_new_part.end(), next_split_new_part.begin(), next_split_new_part.end() );
+        cindex++;
+      }
+      part_out.insert( part_out.end(), split_new_part.begin(), split_new_part.end() );
+
+      for( std::map< Node, Node >::iterator it2 = dni_rem.begin(); it2 != dni_rem.end(); ++it2 ){
+        dni[it2->first].erase( it2->second );
+      }
+    }
+  }
 }
 
 //postcondition: if cycle detected, explanation is why n is a subterm of on
@@ -1272,8 +1462,7 @@ Node TheoryDatatypes::searchForCycle( Node n, Node on,
   if( !firstTime ){
     nn = getRepresentative( n );
     if( nn==on ){
-      Node lit = NodeManager::currentNM()->mkNode( EQUAL, n, nn );
-      explain( lit, explanation );
+      explainEquality( n, nn, true, explanation );
       return on;
     }
   }else{
@@ -1293,8 +1482,7 @@ Node TheoryDatatypes::searchForCycle( Node n, Node on,
             //}
             //add explanation for why the constructor is connected
             if( n != ncons ) {
-              Node lit = NodeManager::currentNM()->mkNode( EQUAL, n, ncons );
-              explain( lit, explanation );
+              explainEquality( n, ncons, true, explanation );
             }
             return on;
           }else if( !cn.isNull() ){
@@ -1306,7 +1494,13 @@ Node TheoryDatatypes::searchForCycle( Node n, Node on,
     visited.erase( nn );
     return Node::null();
   }else{
-    return nn;
+    if( DatatypesRewriter::isTermDatatype( nn ) ) {
+      const Datatype& dt = ((DatatypeType)(nn.getType()).toType()).getDatatype();
+      if( !dt.isCodatatype() ){
+        return nn;
+      }
+    }
+    return Node::null();
   }
 }
 
index 2a93878d0b0edccff1770fe6c4a83fcf60c39088..eb8d792cf94745612cf7c26bbefbe19b30681ed3 100644 (file)
@@ -195,6 +195,9 @@ public:
   /** propagate */
   bool propagate(TNode literal);
   /** explain */
+  void addAssumptions( std::vector<TNode>& assumptions, std::vector<TNode>& tassumptions );
+  void explainEquality( TNode a, TNode b, bool polarity, std::vector<TNode>& assumptions );
+  void explainPredicate( TNode p, bool polarity, std::vector<TNode>& assumptions );
   void explain( TNode literal, std::vector<TNode>& assumptions );
   Node explain( TNode literal );
   /** Conflict when merging two constants */
@@ -236,6 +239,11 @@ private:
   Node searchForCycle( Node n, Node on,
                        std::map< Node, bool >& visited,
                        std::vector< TNode >& explanation, bool firstTime = true );
+  /** for checking whether two codatatype terms must be equal */
+  void separateBisimilar( std::vector< Node >& part, std::vector< std::vector< Node > >& part_out,
+                          std::vector< TNode >& exp,
+                          std::map< Node, Node >& cn,
+                          std::map< Node, std::map< Node, int > >& dni, int dniLvl, bool mkExp );
   /** collect terms */
   void collectTerms( Node n );
   /** get instantiate cons */
index 02e0b6be8636ac3c22f90ef5066d316eee48e95f..a8f3b404ac622128aa6611744798d645efcbe5f4 100644 (file)
@@ -448,6 +448,7 @@ public:
 private:
   std::string d_name;
   std::vector<Type> d_params;
+  bool d_isCo;
   std::vector<DatatypeConstructor> d_constructors;
   bool d_resolved;
   Type d_self;
@@ -494,13 +495,13 @@ private:
 public:
 
   /** Create a new Datatype of the given name. */
-  inline explicit Datatype(std::string name);
+  inline explicit Datatype(std::string name, bool isCo = false);
 
   /**
    * Create a new Datatype of the given name, with the given
    * parameterization.
    */
-  inline Datatype(std::string name, const std::vector<Type>& params);
+  inline Datatype(std::string name, const std::vector<Type>& params, bool isCo = false);
 
   /**
    * Add a constructor to this Datatype.  Constructor names need not
@@ -526,6 +527,9 @@ public:
   /** Get parameters */
   inline std::vector<Type> getParameters() const;
 
+  /** is this a co-datatype? */
+  inline bool isCodatatype() const;
+
   /**
    * Return the cardinality of this datatype (the sum of the
    * cardinalities of its constructors).  The Datatype must be
@@ -662,18 +666,20 @@ inline std::string DatatypeUnresolvedType::getName() const throw() {
   return d_name;
 }
 
-inline Datatype::Datatype(std::string name) :
+inline Datatype::Datatype(std::string name, bool isCo) :
   d_name(name),
   d_params(),
+  d_isCo(isCo),
   d_constructors(),
   d_resolved(false),
   d_self(),
   d_card(CardinalityUnknown()) {
 }
 
-inline Datatype::Datatype(std::string name, const std::vector<Type>& params) :
+inline Datatype::Datatype(std::string name, const std::vector<Type>& params, bool isCo) :
   d_name(name),
   d_params(params),
+  d_isCo(isCo),
   d_constructors(),
   d_resolved(false),
   d_self(),
@@ -707,6 +713,10 @@ inline std::vector<Type> Datatype::getParameters() const {
   return d_params;
 }
 
+inline bool Datatype::isCodatatype() const {
+  return d_isCo;
+}
+
 inline bool Datatype::operator!=(const Datatype& other) const throw() {
   return !(*this == other);
 }