Accelerate sygus solution reconstruction for constants and id functions. Minor chang...
authorajreynol <andrew.j.reynolds@gmail.com>
Fri, 12 Jun 2015 12:15:14 +0000 (14:15 +0200)
committerajreynol <andrew.j.reynolds@gmail.com>
Fri, 12 Jun 2015 12:15:14 +0000 (14:15 +0200)
src/parser/smt2/smt2.cpp
src/theory/quantifiers/ce_guided_single_inv_sol.cpp
src/theory/quantifiers/term_database.cpp
src/theory/quantifiers/term_database.h
src/util/datatype.cpp
src/util/datatype.h

index 47022da3ea0dc18937a605cf80b4610b04f37871..8ed8e40a181533a8b0bd8c6a845526d33b5780c1 100644 (file)
@@ -1001,17 +1001,22 @@ void Smt2::mkSygusDatatype( CVC4::Datatype& dt, std::vector<CVC4::Expr>& ops,
 
 void Smt2::addSygusDatatypeConstructor( CVC4::Datatype& dt, CVC4::Expr op, std::string& cname, std::vector< CVC4::Type >& cargs,
                                         CVC4::Expr& let_body, std::vector< CVC4::Expr >& let_args, unsigned let_num_input_args ) {
-  
+  Debug("parser-sygus") << "--> Add constructor " << cname << " to " << dt.getName() << std::endl;
+  if( !let_body.isNull() ){
+    Debug("parser-sygus") << "    let body = " << let_body << ", args = " << let_args.size() << "," << let_num_input_args << std::endl;
+    //TODO : remove arguments not occurring in body
+    //if this is a self identity function, ignore
+    if( let_args.size()==0 && let_args[0]==let_body ){
+      Debug("parser-sygus") << "    identity function " << cargs[0] << " to " << dt.getName() << std::endl;
+      //TODO
+    }
+  }
   std::string name = dt.getName() + "_" + cname;
   std::string testerId("is-");
   testerId.append(name);
   checkDeclaration(name, CHECK_UNDECLARED, SYM_VARIABLE);
   checkDeclaration(testerId, CHECK_UNDECLARED, SYM_VARIABLE);
   CVC4::DatatypeConstructor c(name, testerId );
-  Debug("parser-sygus") << "--> Add constructor " << cname << " to " << dt.getName() << std::endl;
-  if( !let_body.isNull() ){
-    Debug("parser-sygus") << "    let body = " << let_body << ", args = " << let_args.size() << "," << let_num_input_args << std::endl;
-  }
   c.setSygus( op, let_body, let_args, let_num_input_args );
   for( unsigned j=0; j<cargs.size(); j++ ){
     std::stringstream sname;
index 7cefb0aec67f39c9c94dea6f8aed1b511cbbec20..845e20795e2e82f4c24ba85c66c4221c68e3dd7b 100644 (file)
@@ -717,7 +717,7 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in
     //check if op is in syntax sort
     carg = d_qe->getTermDatabaseSygus()->getOpArg( stn, min_t );
     if( carg!=-1 ){
-      Trace("csi-rcons-debug") << "  Type has constant." << std::endl;
+      Trace("csi-rcons-debug") << "  Type has operator." << std::endl;
       d_reconstruct[id] = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) );
       status = 0;
     }else{
@@ -754,117 +754,140 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in
           Trace("csi-rcons-debug") << "Type for " << id << " has kind " << min_t.getKind() << ", but argument # mismatch." << std::endl;
         }
       }
-    }
-    if( status!=0 ){
-      if( min_t.isConst() ){
-        Node min_t_c = d_qe->getTermDatabaseSygus()->builtinToSygusConst( min_t, stn );
-        if( !min_t_c.isNull() ){
-          d_reconstruct[id] = min_t_c;
-          status = 0;
-        }
-      }
       if( status!=0 ){
-        Trace("csi-rcons-debug") << "Try matching for " << id << "." << std::endl;
-        //try other options
-        //match against other constructors
-        bool success;
-        int c_index = 0;
-        do{
-          success = false;
-          int index_found;
-          std::vector< Node > args;
-          if( d_qe->getTermDatabaseSygus()->getMatch( min_t, stn, index_found, args, karg, c_index ) ){
-            success = true;
+        //try constant reconstruction
+        if( min_t.isConst() ){
+          Trace("csi-rcons-debug") << "...try constant reconstruction." << std::endl;
+          Node min_t_c = d_qe->getTermDatabaseSygus()->builtinToSygusConst( min_t, stn );
+          if( !min_t_c.isNull() ){
+            Trace("csi-rcons-debug") << "   constant reconstruction success for " << id << ", result = " << min_t_c << std::endl;
+            d_reconstruct[id] = min_t_c;
             status = 0;
-            Node cons = Node::fromExpr( dt[index_found].getConstructor() );
-            Trace("csi-rcons-debug") << "Try alternative for " << id << ", matching " << dt[index_found].getName() << " with children : " << std::endl;
-            for( unsigned i=0; i<args.size(); i++ ){
-              Trace("csi-rcons-debug") << "  " << args[i] << std::endl;
-            }
-            if( !collectReconstructNodes( id, args, dt[index_found], d_reconstruct_op[id][cons], status ) ){
+          }
+        }
+        if( status!=0 ){
+          //try identity functions
+          for( unsigned i=0; i<d_qe->getTermDatabaseSygus()->getNumIdFuncs( stn ); i++ ){
+            unsigned ii = d_qe->getTermDatabaseSygus()->getIdFuncIndex( stn, i );
+            Assert( dt[ii].getNumArgs()==1 );
+            //try to directly reconstruct from single argument
+            std::vector< Node > tchildren;
+            tchildren.push_back( min_t );
+            TypeNode stnc = TypeNode::fromType( ((SelectorType)dt[ii][0].getType()).getRangeType() );
+            Trace("csi-rcons-debug") << "...try identity function " << dt[ii].getSygusOp() << ", child type is " << stnc << std::endl;
+            status = 0;
+            Node cons = Node::fromExpr( dt[ii].getConstructor() );
+            if( !collectReconstructNodes( id, tchildren, dt[ii], d_reconstruct_op[id][cons], status ) ){
               d_reconstruct_op[id].erase( cons );
               status = 1;
             }else{
-              c_index = index_found+1;
+              Trace("csi-rcons-debug") << "   identity function success for " << id << std::endl;
+              break;
             }
           }
-        }while( success && status!=0 );
-
-        if( status!=0 ){
-          // construct an equivalence class of terms that are equivalent to t
-          if( d_rep[id]==id ){
-            Trace("csi-rcons-debug") << "Try rewriting for " << id << "." << std::endl;
-            //get equivalence class of term
-            std::vector< Node > equiv;
-            if( tn.isBoolean() ){
-              Node curr = min_t;
-              Node new_t;
-              do{
-                new_t = Node::null();
-                if( curr.getKind()==EQUAL && ( curr[0].getType().isInteger() || curr[0].getType().isReal() ) ){
-                  new_t = NodeManager::currentNM()->mkNode( AND, NodeManager::currentNM()->mkNode( LEQ, curr[0], curr[1] ),
-                                                                NodeManager::currentNM()->mkNode( LEQ, curr[1], curr[0] ) );
-                }else if( curr.getKind()==ITE ){
-                  new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ),
-                                                                NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[2] ) );
-                }else if( curr.getKind()==IFF ){
-                  new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ),
-                                                                NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[1].negate() ) );
-                }else if( curr.getKind()==OR || curr.getKind()==AND ){
-                  new_t = TermDb::simpleNegate( curr ).negate();
-                }else if( curr.getKind()==NOT ){
-                  new_t = TermDb::simpleNegate( curr[0] );
+          if( status!=0 ){
+            //try other options, such as matching against other constructors
+            Trace("csi-rcons-debug") << "Try matching for " << id << "." << std::endl;
+            bool success;
+            int c_index = 0;
+            do{
+              success = false;
+              int index_found;
+              std::vector< Node > args;
+              if( d_qe->getTermDatabaseSygus()->getMatch( min_t, stn, index_found, args, karg, c_index ) ){
+                success = true;
+                status = 0;
+                Node cons = Node::fromExpr( dt[index_found].getConstructor() );
+                Trace("csi-rcons-debug") << "Try alternative for " << id << ", matching " << dt[index_found].getName() << " with children : " << std::endl;
+                for( unsigned i=0; i<args.size(); i++ ){
+                  Trace("csi-rcons-debug") << "  " << args[i] << std::endl;
+                }
+                if( !collectReconstructNodes( id, args, dt[index_found], d_reconstruct_op[id][cons], status ) ){
+                  d_reconstruct_op[id].erase( cons );
+                  status = 1;
                 }else{
-                  new_t = NodeManager::currentNM()->mkNode( NOT, NodeManager::currentNM()->mkNode( NOT, curr ) );
+                  c_index = index_found+1;
                 }
-                if( !new_t.isNull() ){
-                  if( new_t!=min_t && std::find( equiv.begin(), equiv.end(), new_t )==equiv.end() ){
-                    curr = new_t;
-                    equiv.push_back( new_t );
-                  }else{
+              }
+            }while( success && status!=0 );
+
+            if( status!=0 ){
+              // construct an equivalence class of terms that are equivalent to t
+              if( d_rep[id]==id ){
+                Trace("csi-rcons-debug") << "Try rewriting for " << id << "." << std::endl;
+                //get equivalence class of term
+                std::vector< Node > equiv;
+                if( tn.isBoolean() ){
+                  Node curr = min_t;
+                  Node new_t;
+                  do{
                     new_t = Node::null();
+                    if( curr.getKind()==EQUAL && ( curr[0].getType().isInteger() || curr[0].getType().isReal() ) ){
+                      new_t = NodeManager::currentNM()->mkNode( AND, NodeManager::currentNM()->mkNode( LEQ, curr[0], curr[1] ),
+                                                                    NodeManager::currentNM()->mkNode( LEQ, curr[1], curr[0] ) );
+                    }else if( curr.getKind()==ITE ){
+                      new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ),
+                                                                    NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[2] ) );
+                    }else if( curr.getKind()==IFF ){
+                      new_t = NodeManager::currentNM()->mkNode( OR, NodeManager::currentNM()->mkNode( AND, curr[0], curr[1] ),
+                                                                    NodeManager::currentNM()->mkNode( AND, curr[0].negate(), curr[1].negate() ) );
+                    }else if( curr.getKind()==OR || curr.getKind()==AND ){
+                      new_t = TermDb::simpleNegate( curr ).negate();
+                    }else if( curr.getKind()==NOT ){
+                      new_t = TermDb::simpleNegate( curr[0] );
+                    }else{
+                      new_t = NodeManager::currentNM()->mkNode( NOT, NodeManager::currentNM()->mkNode( NOT, curr ) );
+                    }
+                    if( !new_t.isNull() ){
+                      if( new_t!=min_t && std::find( equiv.begin(), equiv.end(), new_t )==equiv.end() ){
+                        curr = new_t;
+                        equiv.push_back( new_t );
+                      }else{
+                        new_t = Node::null();
+                      }
+                    }
+                  }while( !new_t.isNull() );
+                }
+                for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
+                  Kind k = d_qe->getTermDatabaseSygus()->getArgKind( stn, i );
+                  if( k==AND || k==OR ){
+                    equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, min_t ) );
+                    equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, NodeManager::currentNM()->mkConst( k==AND ) ) );
                   }
                 }
-              }while( !new_t.isNull() );
-            }
-            for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
-              Kind k = d_qe->getTermDatabaseSygus()->getArgKind( stn, i );
-              if( k==AND || k==OR ){
-                equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, min_t ) );
-                equiv.push_back( NodeManager::currentNM()->mkNode( k, min_t, NodeManager::currentNM()->mkConst( k==AND ) ) );
-              }
-            }
-            //assign ids to terms
-            Trace("csi-rcons-debug") << "Term " << id << " is equivalent to " << equiv.size() << " terms : " << std::endl;
-            std::vector< int > equiv_ids;
-            for( unsigned i=0; i<equiv.size(); i++ ){
-              Trace("csi-rcons-debug") << "  " << equiv[i] << std::endl;
-              if( d_rcons_to_id[stn].find( equiv[i] )==d_rcons_to_id[stn].end() ){
-                int eq_id = allocate( equiv[i], stn );
-                d_eqc.erase( eq_id );
-                d_rep[eq_id] = id;
-                d_eqc[id].push_back( eq_id );
-                equiv_ids.push_back( eq_id );
-              }else{
-                equiv_ids.push_back( -1 );
-              }
-            }
-            // now, try each of them
-            for( unsigned i=0; i<equiv.size(); i++ ){
-              if( equiv_ids[i]!=-1 ){
-                collectReconstructNodes( equiv[i], stn, status );
-                //if one succeeds
-                if( status==0 ){
-                  Node rsol = getReconstructedSolution( equiv_ids[i] );
-                  Assert( !rsol.isNull() );
-                  //set all members of the equivalence class that this is the reconstructed solution
-                  setReconstructed( id, rsol );
-                  break;
+                //assign ids to terms
+                Trace("csi-rcons-debug") << "Term " << id << " is equivalent to " << equiv.size() << " terms : " << std::endl;
+                std::vector< int > equiv_ids;
+                for( unsigned i=0; i<equiv.size(); i++ ){
+                  Trace("csi-rcons-debug") << "  " << equiv[i] << std::endl;
+                  if( d_rcons_to_id[stn].find( equiv[i] )==d_rcons_to_id[stn].end() ){
+                    int eq_id = allocate( equiv[i], stn );
+                    d_eqc.erase( eq_id );
+                    d_rep[eq_id] = id;
+                    d_eqc[id].push_back( eq_id );
+                    equiv_ids.push_back( eq_id );
+                  }else{
+                    equiv_ids.push_back( -1 );
+                  }
                 }
+                // now, try each of them
+                for( unsigned i=0; i<equiv.size(); i++ ){
+                  if( equiv_ids[i]!=-1 ){
+                    collectReconstructNodes( equiv[i], stn, status );
+                    //if one succeeds
+                    if( status==0 ){
+                      Node rsol = getReconstructedSolution( equiv_ids[i] );
+                      Assert( !rsol.isNull() );
+                      //set all members of the equivalence class that this is the reconstructed solution
+                      setReconstructed( id, rsol );
+                      break;
+                    }
+                  }
+                }
+              }else{
+                Trace("csi-rcons-debug") << "Do not try rewriting for " << id << ", rep = " << d_rep[id] << std::endl;
               }
             }
-          }else{
-            Trace("csi-rcons-debug") << "Do not try rewriting for " << id << ", rep = " << d_rep[id] << std::endl;
           }
         }
       }
@@ -880,6 +903,7 @@ int CegConjectureSingleInvSol::collectReconstructNodes( Node t, TypeNode stn, in
 }
 
 bool CegConjectureSingleInvSol::collectReconstructNodes( int pid, std::vector< Node >& ts, const DatatypeConstructor& dtc, std::vector< int >& ids, int& status ) {
+  Assert( dtc.getNumArgs()==ts.size() );
   for( unsigned i=0; i<ts.size(); i++ ){
     TypeNode cstn = d_qe->getTermDatabaseSygus()->getArgType( dtc, i );
     int cstatus;
index 60573a7fc3c495a9be061b914793ef40e880493f..646a1565e3a4f7757e4ca8d8bc344beb726298af 100644 (file)
@@ -1359,7 +1359,10 @@ int TermDb::getQAttrRewriteRulePriority( Node q ) {
 
 
 
-
+TermDbSygus::TermDbSygus(){
+  d_true = NodeManager::currentNM()->mkConst( true );
+  d_false = NodeManager::currentNM()->mkConst( false );
+}
 
 TNode TermDbSygus::getVar( TypeNode tn, int i ) {
   while( i>=(int)d_fv[tn].size() ){
@@ -1534,7 +1537,7 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int
   }
   Trace("sygus-db") << "mkGeneric " << dt.getName() << " " << op << " " << op.getKind() << "..." << std::endl;
   for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){
-    TypeNode tna = TypeNode::fromType( ((SelectorType)dt[c][i].getType()).getRangeType() );
+    TypeNode tna = getArgType( dt[c], i );
     Node a;
     std::map< int, Node >::iterator it = pre.find( i );
     if( it!=pre.end() ){
@@ -1589,14 +1592,18 @@ Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) {
   }
 }
 
-Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) {
+//rcons_depth limits the number of recursive calls when doing accelerated constant reconstruction (currently limited to 1000)
+//this is hacky : depending upon order of calls, constant rcons may succeed, e.g. 1001, 999 vs. 999, 1001
+Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) {
   std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c );
   if( it==d_builtin_const_to_sygus[tn].end() ){
+    Node sc;
+    d_builtin_const_to_sygus[tn][c] = sc;
     Assert( c.isConst() );
     Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) );
     const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
+    Trace("csi-rcons-debug") << "Try to reconstruct " << c << " in " << dt.getName() << std::endl;
     Assert( dt.isSygus() );
-    Node sc;
     // if we are not interested in reconstructing constants, or the grammar allows them, return a proxy
     if( !options::cegqiSingleInvReconstructConst() || dt.getSygusAllowConst() ){
       Node k = NodeManager::currentNM()->mkSkolem( "sy", tn, "sygus proxy" );
@@ -1606,9 +1613,60 @@ Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) {
     }else{
       int carg = getOpArg( tn, c );
       if( carg!=-1 ){
-        sc = Node::fromExpr( dt[carg].getSygusOp() );
+        //sc = Node::fromExpr( dt[carg].getSygusOp() );
+        sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) );
       }else{
-        //TODO
+        //identity functions
+        for( unsigned i=0; i<getNumIdFuncs( tn ); i++ ){
+          unsigned ii = getIdFuncIndex( tn, i );
+          Assert( dt[ii].getNumArgs()==1 );
+          //try to directly reconstruct from single argument
+          TypeNode tnc = getArgType( dt[ii], 0 );
+          Trace("csi-rcons-debug") << "Based on id function " << dt[ii].getSygusOp() << ", try reconstructing " << c << " instead in " << tnc << std::endl;
+          Node n = builtinToSygusConst( c, tnc, rcons_depth );
+          if( !n.isNull() ){
+            sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[ii].getConstructor() ), n );
+            break;
+          }
+        }
+        if( sc.isNull() ){
+          if( rcons_depth<1000 ){
+            //accelerated, recursive reconstruction of constants
+            Kind pk = getPlusKind( TypeNode::fromType( dt.getSygusType() ) );
+            if( pk!=UNDEFINED_KIND ){
+              int arg = getKindArg( tn, pk );
+              if( arg!=-1 ){
+                Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) );
+                Kind pkm = getPlusKind( TypeNode::fromType( dt.getSygusType() ), true );
+                //get types
+                Assert( dt[arg].getNumArgs()==2 );
+                TypeNode tn1 = getArgType( dt[arg], 0 );
+                TypeNode tn2 = getArgType( dt[arg], 1 );
+                //iterate over all positive constants, largest to smallest
+                int start = d_const_list[tn1].size()-1;
+                int end = d_const_list[tn1].size()-d_const_list_pos[tn1];
+                for( int i=start; i>=end; --i ){
+                  Node c1 = d_const_list[tn1][i];
+                  //only consider if smaller than c, and 
+                  if( doCompare( c1, c, ck ) ){
+                    Node c2 = NodeManager::currentNM()->mkNode( pkm, c, c1 );
+                    c2 = Rewriter::rewrite( c2 );
+                    if( c2.isConst() ){
+                      //reconstruct constant on the other side
+                      Node sc2 = builtinToSygusConst( c2, tn2, rcons_depth+1 );
+                      if( !sc2.isNull() ){
+                        Node sc1 = builtinToSygusConst( c1, tn1, rcons_depth );
+                        Assert( !sc1.isNull() );
+                        sc = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[arg].getConstructor() ), sc1, sc2 );
+                        break;
+                      }
+                    }
+                  }
+                }
+              }
+            }
+          }
+        }
       }
     }
     d_builtin_const_to_sygus[tn][c] = sc;
@@ -1798,7 +1856,7 @@ Node TermDbSygus::getTypeValue( TypeNode tn, int val ) {
       n = NodeManager::currentNM()->mkConst<BitVector>(bval);
     }else if( tn.isBoolean() ){
       if( val==0 ){
-        n = NodeManager::currentNM()->mkConst( false );
+        n = d_false;
       }
     }
     d_type_value[tn][val] = n;
@@ -1815,7 +1873,7 @@ Node TermDbSygus::getTypeMaxValue( TypeNode tn ) {
     if( tn.isBitVector() ){
       n = bv::utils::mkOnes(tn.getConst<BitVectorSize>());
     }else if( tn.isBoolean() ){
-      n = NodeManager::currentNM()->mkConst( true );
+      n = d_true;
     }
     d_type_max_value[tn] = n;
     return n;
@@ -1847,6 +1905,18 @@ Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& st
   }
 }
 
+struct sortConstants {
+  TermDbSygus * d_tds;
+  Kind d_comp_kind;
+  bool operator() (Node i, Node j) {
+    if( i!=j ){
+      return d_tds->doCompare( i, j, d_comp_kind );
+    }else{
+      return false;
+    }
+  }
+};
+
 void TermDbSygus::registerSygusType( TypeNode tn ){
   if( d_register.find( tn )==d_register.end() ){
     if( !datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){
@@ -1858,6 +1928,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ){
       if( d_register[tn].isNull() ){
         Trace("sygus-util") << "...not sygus." << std::endl;
       }else{
+        //for constant reconstruction
+        Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) );
+        Node z = getTypeValue( TypeNode::fromType( dt.getSygusType() ), 0 );
+        d_const_list_pos[tn] = 0;
+        //iterate over constructors
         for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
           Expr sop = dt[i].getSygusOp();
           Assert( !sop.isNull() );
@@ -1872,11 +1947,39 @@ void TermDbSygus::registerSygusType( TypeNode tn ){
             Trace("sygus-util") << ", constant";
             d_consts[tn][n] = i;
             d_arg_const[tn][i] = n;
+            d_const_list[tn].push_back( n );
+            if( ck!=UNDEFINED_KIND && doCompare( z, n, ck ) ){
+              d_const_list_pos[tn]++;
+            }
+          }
+          if( dt[i].isSygusIdFunc() ){
+            d_id_funcs[tn].push_back( i );
           }
           d_ops[tn][n] = i;
           d_arg_ops[tn][i] = n;
           Trace("sygus-util") << std::endl;
         }
+        //sort the constant list
+        if( !d_const_list[tn].empty() ){
+          if( ck!=UNDEFINED_KIND ){
+            sortConstants sc;
+            sc.d_comp_kind = ck;
+            sc.d_tds = this;
+            std::sort( d_const_list[tn].begin(), d_const_list[tn].end(), sc );
+          }
+          Trace("sygus-util") << "Type has " << d_const_list[tn].size() << " constants..." << std::endl << "  ";
+          for( unsigned i=0; i<d_const_list[tn].size(); i++ ){
+            Trace("sygus-util") << d_const_list[tn][i] << " ";
+          }
+          Trace("sygus-util") << std::endl;
+          Trace("sygus-util") << "Of these, " << d_const_list_pos[tn] << " are marked as positive." << std::endl;
+        }
+        //register connected types
+        for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
+          for( unsigned j=0; j<dt[i].getNumArgs(); j++ ){
+            registerSygusType( getArgType( dt[i], j ) );
+          }
+        }
       }
     }
   }
@@ -1979,6 +2082,14 @@ bool TermDbSygus::isConstArg( TypeNode tn, int i ) {
   }
 }
 
+unsigned TermDbSygus::getNumIdFuncs( TypeNode tn ) {
+  return d_id_funcs[tn].size();
+}
+
+unsigned TermDbSygus::getIdFuncIndex( TypeNode tn, unsigned i ) {
+  return d_id_funcs[tn][i];
+}
+
 TypeNode TermDbSygus::getArgType( const DatatypeConstructor& c, int i ) {
   Assert( i>=0 && i<(int)c.getNumArgs() );
   return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() );
@@ -2041,7 +2152,33 @@ Node TermDbSygus::expandBuiltinTerm( Node t ){
 }
 
 
-void doReplace(std::string& str, const std::string& oldStr, const std::string& newStr){
+Kind TermDbSygus::getComparisonKind( TypeNode tn ) {
+  if( tn.isInteger() ){
+    return LT;
+  }else if( tn.isBitVector() ){
+    return BITVECTOR_ULT;
+  }else{
+    return UNDEFINED_KIND;
+  }
+}
+
+Kind TermDbSygus::getPlusKind( TypeNode tn, bool is_neg ) {
+  if( tn.isInteger() ){
+    return is_neg ? MINUS : PLUS;
+  }else if( tn.isBitVector() ){
+    return is_neg ? BITVECTOR_SUB : BITVECTOR_PLUS;
+  }else{
+    return UNDEFINED_KIND;
+  }
+}
+
+bool TermDbSygus::doCompare( Node a, Node b, Kind k ) {
+  Node com = NodeManager::currentNM()->mkNode( k, a, b );
+  com = Rewriter::rewrite( com );
+  return com==d_true;
+}
+
+void doStrReplace(std::string& str, const std::string& oldStr, const std::string& newStr){
   size_t pos = 0;
   while((pos = str.find(oldStr, pos)) != std::string::npos){
      str.replace(pos, oldStr.length(), newStr);
@@ -2101,12 +2238,16 @@ void TermDbSygus::printSygusTerm( std::ostream& out, Node n, std::vector< Node >
         std::stringstream body_out;
         printSygusTerm( body_out, let_body, new_lvs );
         std::string body = body_out.str();
-        for( unsigned i=dt[cIndex].getNumSygusLetInputArgs(); i<dt[cIndex].getNumSygusLetArgs(); i++ ){ 
+        for( unsigned i=0; i<dt[cIndex].getNumSygusLetArgs(); i++ ){ 
           std::stringstream old_str;
           old_str << new_lvs[i];
           std::stringstream new_str;
-          printSygusTerm( new_str, n[i], lvs );
-          doReplace( body, old_str.str().c_str(), new_str.str().c_str() );
+          if( i>=dt[cIndex].getNumSygusLetInputArgs() ){
+            printSygusTerm( new_str, n[i], lvs );
+          }else{
+            new_str << Node::fromExpr( dt[cIndex].getSygusLetArg( i ) );
+          }
+          doStrReplace( body, old_str.str().c_str(), new_str.str().c_str() );
         }
         out << body;
         if( dt[cIndex].getNumSygusLetInputArgs()>0 ){
index 0bb2c32246c76abe8ccafc771e01bba8692923c3..455287feb7401f23bb7ba59ec32c8f50d4b1bf31 100644 (file)
@@ -382,6 +382,8 @@ private:
   std::map< TypeNode, std::vector< Node > > d_fv;
   std::map< Node, TypeNode > d_fv_stype;
   std::map< Node, int > d_fv_num;
+  Node d_true;
+  Node d_false;
 public:
   TNode getVar( TypeNode tn, int i );
   TNode getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count );
@@ -404,6 +406,9 @@ private:
   std::map< TypeNode, std::map< Node, int > > d_consts;
   std::map< TypeNode, std::map< Node, int > > d_ops;
   std::map< TypeNode, std::map< int, Node > > d_arg_ops;
+  std::map< TypeNode, std::vector< int > > d_id_funcs;
+  std::map< TypeNode, std::vector< Node > > d_const_list; //sorted list of constants for type
+  std::map< TypeNode, unsigned > d_const_list_pos;
   //information for builtin types
   std::map< TypeNode, std::map< int, Node > > d_type_value;
   std::map< TypeNode, Node > d_type_max_value;
@@ -414,7 +419,7 @@ private:
   std::map< TypeNode, std::map< Node, Node > > d_sygus_to_builtin;
   std::map< TypeNode, std::map< Node, Node > > d_builtin_const_to_sygus;
 public:
-  TermDbSygus(){}
+  TermDbSygus();
   bool isRegistered( TypeNode tn );
   int getKindArg( TypeNode tn, Kind k );
   int getConstArg( TypeNode tn, Node n );
@@ -427,6 +432,8 @@ public:
   Kind getArgKind( TypeNode tn, int i );
   bool isKindArg( TypeNode tn, int i );
   bool isConstArg( TypeNode tn, int i );
+  unsigned getNumIdFuncs( TypeNode tn );
+  unsigned getIdFuncIndex( TypeNode tn, unsigned i );
   void registerSygusType( TypeNode tn );
   /** get arg type */
   TypeNode getArgType( const DatatypeConstructor& c, int i );
@@ -451,7 +458,7 @@ public:
   TypeNode getSygusType( Node v );
   Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre );
   Node sygusToBuiltin( Node n, TypeNode tn );
-  Node builtinToSygusConst( Node c, TypeNode tn );
+  Node builtinToSygusConst( Node c, TypeNode tn, int rcons_depth = 0 );
   Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs );
   Node getNormalized( TypeNode t, Node prog, bool do_pre_norm = false, bool do_post_norm = true );
   int getTermSize( Node n );
@@ -459,6 +466,10 @@ public:
   Node minimizeBuiltinTerm( Node n );
   /** given a term, expand it into more basic components */
   Node expandBuiltinTerm( Node n );
+  /** get comparison kind */
+  Kind getComparisonKind( TypeNode tn );
+  Kind getPlusKind( TypeNode tn, bool is_neg = false );
+  bool doCompare( Node a, Node b, Kind k );
   /** print sygus term */
   static void printSygusTerm( std::ostream& out, Node n, std::vector< Node >& lvs );
 };
index b1ab011efd9c08d6ca40d8614a7e360a86f2839e..138fb4bb05d0d8803394286fd9af3c9c1096fccd 100644 (file)
@@ -707,6 +707,11 @@ unsigned DatatypeConstructor::getNumSygusLetInputArgs() const {
   CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved");
   return d_sygus_num_let_input_args;
 }
+
+bool DatatypeConstructor::isSygusIdFunc() const {
+  CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved");
+  return d_sygus_let_args.size()==1 && d_sygus_let_args[0]==d_sygus_let_body;
+}
   
 Cardinality DatatypeConstructor::getCardinality() const throw(IllegalArgumentException) {
   CheckArgument(isResolved(), this, "this datatype constructor is not yet resolved");
index 1945c4390fea0bc28a0a872427aa165067168af2..0b8b8c61f881bd7aab1d33ff3befd89435b24b44 100644 (file)
@@ -294,6 +294,8 @@ public:
   Expr getSygusLetArg( unsigned i ) const;
   /** get number of let arguments that should be printed as arguments to let */
   unsigned getNumSygusLetInputArgs() const;
+  /** is this a sygus identity function */
+  bool isSygusIdFunc() const;
   
   /**
    * Get the tester name for this Datatype constructor.