Add module for sygus search space narrowing based on global state.
authorajreynol <andrew.j.reynolds@gmail.com>
Thu, 29 Jan 2015 07:27:18 +0000 (08:27 +0100)
committerajreynol <andrew.j.reynolds@gmail.com>
Thu, 29 Jan 2015 07:27:18 +0000 (08:27 +0100)
src/theory/datatypes/datatypes_sygus.cpp
src/theory/datatypes/datatypes_sygus.h
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h

index bf17cf5e4a4b63cce4083e2d5ef1f620485edf63..8bf3d4d64bb32cdc79db0cab9595286a2695a3b4 100644 (file)
@@ -96,40 +96,6 @@ bool SygusSplit::isConstArg( TypeNode tn, int i ) {
   }
 }
 
-Node SygusSplit::getVar( TypeNode tn, int i ) {
-  while( i>=(int)d_fv[tn].size() ){
-    std::stringstream ss;
-    TypeNode vtn = tn;
-    if( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){
-      const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
-      ss << "fv_" << dt.getName() << "_" << i;
-      Assert( d_register.find( tn )!=d_register.end() );
-      if( !d_register[tn].isNull() ){
-        vtn = d_register[tn];
-      }
-    }else{
-      ss << "fv_" << tn << "_" << i;
-    }
-    Assert( !vtn.isNull() );
-    Node v = NodeManager::currentNM()->mkSkolem( ss.str(), vtn, "for sygus normal form testing" );
-    d_fv_stype[v] = tn;
-    d_fv[tn].push_back( v );
-  }
-  return d_fv[tn][i];
-}
-
-Node SygusSplit::getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count ) {
-  std::map< TypeNode, int >::iterator it = var_count.find( tn );
-  if( it==var_count.end() ){
-    var_count[tn] = 1;
-    return getVar( tn, 0 );
-  }else{
-    int index = it->second;
-    var_count[tn]++;
-    return getVar( tn, index );
-  }
-}
-
 void SygusSplit::getSygusSplits( Node n, const Datatype& dt, std::vector< Node >& splits, std::vector< Node >& lemmas ) {
   Assert( dt.isSygus() );
   if( d_splits.find( n )==d_splits.end() ){
@@ -180,18 +146,6 @@ void SygusSplit::getSygusSplits( Node n, const Datatype& dt, std::vector< Node >
           addSplit = d_sygus_pc_nred[tnn][csIndex][sIndex][i];
         }
         if( addSplit ){
-          //check based on generic rewriting  TODO
-          //std::vector< int > csIndices;
-          //std::vector< int > sIndices;
-          //csIndices.push_back( i );
-          //TypeNode tng;
-          //Node g = getGeneric( n, csIndices, sIndices, tng );
-          //Trace("sygus-split-debug") << "Generic template " << n << " " << dt[i].getName() << " is " << g << ", sygus type : " << tng << std::endl;
-          //if( isGenericRedundant( tng, g ) ){
-          //  addSplit = false;
-          //  Trace("sygus-split-debug2") << "generic redundant" << std::endl;
-          //}
-
           std::vector< Node > test_c;
           Node test = DatatypesRewriter::mkTester( n, i, dt );
           test_c.push_back( test );
@@ -323,7 +277,7 @@ void SygusSplit::registerSygusType( TypeNode tn ) {
               Trace("sygus-split-debug") << "Check " << dt[i].getName() << " based on generic rewriting" << std::endl;
               std::map< TypeNode, int > var_count;
               std::map< int, Node > pre;
-              Node g = mkGeneric( dt, i, var_count, pre );
+              Node g = d_util->mkGeneric( dt, i, var_count, pre );
               nred = !isGenericRedundant( tn, g );
               Trace("sygus-split-debug") << "...done check " << dt[i].getName() << " based on generic rewriting" << std::endl;
             }
@@ -378,10 +332,10 @@ void SygusSplit::registerSygusTypeConstructorArg( TypeNode tnn, const Datatype&
               //generic rewriting
               std::map< int, Node > prec;
               std::map< TypeNode, int > var_count;
-              Node gc = mkGeneric( dt, i, var_count, prec );
+              Node gc = d_util->mkGeneric( dt, i, var_count, prec );
               std::map< int, Node > pre;
               pre[sIndex] = gc;
-              Node g = mkGeneric( pdt, csIndex, var_count, pre );
+              Node g = d_util->mkGeneric( pdt, csIndex, var_count, pre );
               addSplit = !isGenericRedundant( tnnp, g );
             }
           }
@@ -433,8 +387,8 @@ void SygusSplit::registerSygusTypeConstructorArg( TypeNode tnn, const Datatype&
                           //&& dto[i].getNumArgs()==0 && dt[j].getNumArgs()==0 ){
                           std::map< TypeNode, int > var_count;
                           std::map< int, Node > pre;
-                          Node g1 = mkGeneric( dt, j, var_count, pre );
-                          Node g2 = mkGeneric( dto, i, var_count, pre );
+                          Node g1 = d_util->mkGeneric( dt, j, var_count, pre );
+                          Node g2 = d_util->mkGeneric( dto, i, var_count, pre );
                           Node g = NodeManager::currentNM()->mkNode( parentKind, g1, g2 );
                           if( isGenericRedundant( tnnp, g ) ){
                             rem = true;
@@ -754,50 +708,191 @@ TypeNode SygusSplit::getArgType( const DatatypeConstructor& c, int i ) {
   return TypeNode::fromType( ((SelectorType)c[i].getType()).getRangeType() );
 }
 
-Node SygusSplit::getGeneric( Node n, std::vector< int >& csIndices, std::vector< int >& sIndices, TypeNode& tng ) {
-  if( n.getKind()==APPLY_SELECTOR_TOTAL ){
-    Node op = n.getOperator();
-    Expr selectorExpr = op.toExpr();
-    csIndices.push_back( Datatype::cindexOf(selectorExpr) );
-    sIndices.push_back( Datatype::indexOf(selectorExpr) );
-    return getGeneric( n[0], csIndices, sIndices, tng );
-  }else{
-    tng = n.getType();
-    Assert( DatatypesRewriter::isTypeDatatype( tng ) );
-    const Datatype& dt = ((DatatypeType)(tng).toType()).getDatatype();
-    Assert( csIndices.size()==sIndices.size()+1 );
-    std::reverse( csIndices.begin(), csIndices.end() );
-    std::reverse( sIndices.begin(), sIndices.end() );
-    Trace("sygus-generic") << "Traversed under " << sIndices.size() << " selectors." << std::endl;
+bool SygusSplit::isGenericRedundant( TypeNode tn, Node g ) {
+  //everything added to this cache should be mutually exclusive cases
+  std::map< Node, bool >::iterator it = d_gen_redundant[tn].find( g );
+  if( it==d_gen_redundant[tn].end() ){
+    Trace("sygus-gnf") << "Register generic for " << tn << " : " << g << std::endl;
+    Node gr = Rewriter::rewrite( g );
+    //replace variables in order left to right
     std::map< TypeNode, int > var_count;
-    return getGeneric2( dt, var_count, csIndices, sIndices, 0 );
+    std::map< Node, Node > subs;
+    gr = d_util->getSygusNormalized( gr, var_count, subs );
+    Trace("sygus-gnf-debug") << "Generic " << g << " rewrites to " << gr << std::endl;
+    std::map< Node, Node >::iterator itg = d_gen_terms[tn].find( gr );
+    bool red = true;
+    if( itg==d_gen_terms[tn].end() ){
+      red = false;
+      d_gen_terms[tn][gr] = g;
+      Trace("sygus-gnf-debug") << "...not redundant." << std::endl;
+    }else{
+      Trace("sygus-gnf-debug") << "...redundant." << std::endl;
+      Trace("sygus-nf") << "* Sygus normal form : simplify since " << g << " and " << itg->second << " both rewrite to " << gr << std::endl;
+    }
+    d_gen_redundant[tn][g] = red;
+    return red;
+  }else{
+    return it->second;
   }
 }
 
-Node SygusSplit::getGeneric2( const Datatype& dt, std::map< TypeNode, int >& var_count, std::vector< int >& csIndices, std::vector< int >& sIndices, unsigned index ) {
-  Assert( index<csIndices.size() );
-  std::vector< Node > children;
-  int c = csIndices[index];
-  int s = index<sIndices.size() ? sIndices[index] : -1;
-  Assert( c>=0 && c<(int)dt.getNumConstructors() );
-  Assert( dt.isSygus() );
-  Assert( !dt[c].getSygusOp().isNull() );
-  Node op = Node::fromExpr( dt[c].getSygusOp() );
-  if( op.getKind()!=BUILTIN ){
-    children.push_back( op );
+
+
+SygusSymBreak::SygusSymBreak( SygusUtil * util, context::Context* c ) :
+d_util( util ), d_testers( c ), d_watched_terms( c ), d_watched_count( c ), d_anchor( c ), d_prog_depth( c, 0 ), d_conflict( c ) {
+
+
+}
+
+void SygusSymBreak::addTester( Node tst ) {
+  NodeMap::const_iterator it = d_testers.find( tst[0] );
+  if( it==d_testers.end() ){
+    d_testers[tst[0]] = tst;
+    if( d_anchor.get().isNull() ){
+      if( tst[0].getKind()!=APPLY_SELECTOR_TOTAL ){
+        d_anchor = tst[0];
+        assignTester( tst, 0 );
+      }
+    }else{
+      IntMap::const_iterator it = d_watched_terms.find( tst[0] );
+      if( it!=d_watched_terms.end() ){
+        assignTester( tst, (*it).second );
+      }
+    }
   }
-  Trace("sygus-generic") << "Construct for " << dt[c].getName() << ", arg " << s << ", op " << op << std::endl;
+}
+
+void SygusSymBreak::assignTester( Node tst, int depth ) {
+  Trace("sygus-sym-break-debug") << "SymBreak : Assign tester : " << tst << ", depth = " << depth << std::endl;
+  int tindex = Datatype::indexOf( tst.getOperator().toExpr() );
+  TypeNode tn = tst[0].getType();
+  Assert( DatatypesRewriter::isTypeDatatype( tn ) );
+  const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
+  std::vector< Node > tst_waiting;
+  for( unsigned i=0; i<dt[tindex].getNumArgs(); i++ ){
+    Node sel = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex][i].getSelector() ), tst[0] );
+    NodeMap::const_iterator it = d_testers.find( sel );
+    if( it!=d_testers.end() ){
+      tst_waiting.push_back( (*it).second );
+    }else{
+      Trace("sygus-sym-break-debug") << "...add " << sel << " as watch term for " << (depth+1) << std::endl;
+      d_watched_terms[sel] = depth+1;
+    }
+  }
+  //update watched count
+  IntIntMap::const_iterator it = d_watched_count.find( depth+1 );
+  if( it==d_watched_count.end() ){
+    d_watched_count[depth+1] = dt[tindex].getNumArgs();
+  }else{
+    d_watched_count[depth+1] = d_watched_count[depth+1] + dt[tindex].getNumArgs();
+  }
+  Trace("sygus-sym-break-debug") << "...watched count now " << d_watched_count[depth+1].get() << " for " << (depth+1) << std::endl;
+  //now decrement watch count and process
+  if( depth>0 ){
+    Assert( d_watched_count[depth]>0 );
+    d_watched_count[depth] = d_watched_count[depth] - 1;
+  }
+  processProgramDepth( depth );
+  //assign preexisting testers
+  for( unsigned i=0; i<tst_waiting.size(); i++ ){
+    assignTester( tst_waiting[i], depth+1 );
+  }
+}
+
+Node SygusSymBreak::getCandidateProgramAtDepth( int depth, Node prog, int curr_depth, std::map< TypeNode, int >& var_count, std::vector< Node >& testers ) {
+  Assert( depth>=curr_depth );
+  Trace("sygus-sym-break-debug") << "Reconstructing program for " << prog << " at depth " << curr_depth << "/" << depth << std::endl;
+  NodeMap::const_iterator it = d_testers.find( prog );
+  Assert( it!=d_testers.end() );
+  Node tst = (*it).second;
+  testers.push_back( tst );
+  Assert( tst[0]==prog );
+  int tindex = Datatype::indexOf( tst.getOperator().toExpr() );
+  TypeNode tn = prog.getType();
+  Assert( DatatypesRewriter::isTypeDatatype( tn ) );
+  const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
   std::map< int, Node > pre;
-  if( s!=-1 ){
-    TypeNode tna = getArgType( dt[c], s );
-    Assert( DatatypesRewriter::isTypeDatatype( tna ) );
-    const Datatype& adt = ((DatatypeType)(tna).toType()).getDatatype();
-    pre[s] = getGeneric2( adt, var_count, csIndices, sIndices, index+1 );
+  if( curr_depth<depth ){
+    for( unsigned i=0; i<dt[tindex].getNumArgs(); i++ ){
+      Node sel = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex][i].getSelector() ), prog );
+      pre[i] = getCandidateProgramAtDepth( depth, sel, curr_depth+1, var_count, testers );
+    }
+  }
+  return d_util->mkGeneric( dt, tindex, var_count, pre );
+}
+
+void SygusSymBreak::processProgramDepth( int depth ){
+  if( depth==d_prog_depth.get() && ( depth==0 || ( d_watched_count.find( depth )!=d_watched_count.end() && d_watched_count[depth]==0 ) ) ){
+    d_prog_depth = d_prog_depth + 1;
+    if( depth>0 ){
+      Trace("sygus-sym-break-debug") << "Program is set for depth=" << depth << std::endl;
+      std::map< TypeNode, int > var_count;
+      std::vector< Node > testers;
+      //now have entire information about candidate program at given depth
+      Node prog = getCandidateProgramAtDepth( depth, d_anchor.get(), 0, var_count, testers );
+      if( d_normalized.find( prog )==d_normalized.end() ){
+        Trace("sygus-sym-break") << "Currently considering program : " << prog << " at depth " << depth << std::endl;
+        Node progr = Rewriter::rewrite( prog );
+        std::map< TypeNode, int > var_count;
+        std::map< Node, Node > subs;
+        progr = d_util->getSygusNormalized( progr, var_count, subs );
+        Trace("sygus-sym-break2") << "...rewrites to " << progr << std::endl;
+        d_normalized[prog] = progr;
+        std::map< Node, Node >::iterator it = d_normalized_to_orig.find( progr );
+        if( it==d_normalized_to_orig.end() ){
+          d_normalized_to_orig[progr] = prog;
+        }else{
+          Assert( !testers.empty() );
+          Trace("sygus-nf") << "* Sygus sym break : " << prog << " and " << it->second << " both rewrite to " << progr << std::endl;
+          Node conflict = testers.size()==1 ? testers[0] : NodeManager::currentNM()->mkNode( AND, testers );
+          Trace("sygus-sym-break2") << "Conflict : " << conflict << std::endl;
+        }
+      }
+    }
+    processProgramDepth( depth+1 );
+  }
+}
+
+
+SygusUtil::SygusUtil( Context* c ) {
+  d_split = new SygusSplit( this );
+  d_sym_break = new SygusSymBreak( this, c );
+}
+
+Node SygusUtil::getVar( TypeNode tn, int i ) {
+  while( i>=(int)d_fv[tn].size() ){
+    std::stringstream ss;
+    TypeNode vtn = tn;
+    if( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ){
+      const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype();
+      ss << "fv_" << dt.getName() << "_" << i;
+      if( !dt.getSygusType().isNull() ){
+        vtn = TypeNode::fromType( dt.getSygusType() );
+      }
+    }else{
+      ss << "fv_" << tn << "_" << i;
+    }
+    Assert( !vtn.isNull() );
+    Node v = NodeManager::currentNM()->mkSkolem( ss.str(), vtn, "for sygus normal form testing" );
+    d_fv_stype[v] = tn;
+    d_fv[tn].push_back( v );
+  }
+  return d_fv[tn][i];
+}
+
+Node SygusUtil::getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count ) {
+  std::map< TypeNode, int >::iterator it = var_count.find( tn );
+  if( it==var_count.end() ){
+    var_count[tn] = 1;
+    return getVar( tn, 0 );
+  }else{
+    int index = it->second;
+    var_count[tn]++;
+    return getVar( tn, index );
   }
-  return mkGeneric( dt, c, var_count, pre );
 }
 
-Node SygusSplit::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ) {
+Node SygusUtil::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ) {
   Assert( c>=0 && c<(int)dt.getNumConstructors() );
   Assert( dt.isSygus() );
   Assert( !dt[c].getSygusOp().isNull() );
@@ -807,9 +902,7 @@ Node SygusSplit::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >
     children.push_back( op );
   }
   for( int i=0; i<(int)dt[c].getNumArgs(); i++ ){
-    TypeNode tna = getArgType( dt[c], i );
-    registerSygusType( tna );
-    Assert( d_register.find( tna )!=d_register.end() );
+    TypeNode tna = TypeNode::fromType( ((SelectorType)dt[c][i].getType()).getRangeType() );
     Node a;
     std::map< int, Node >::iterator it = pre.find( i );
     if( it!=pre.end() ){
@@ -818,7 +911,6 @@ Node SygusSplit::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >
       a = getVarInc( tna, var_count );
     }
     Assert( !a.isNull() );
-    Assert( a.getType()==d_register[tna] );
     children.push_back( a );
   }
   if( Trace.isOn("sygus-split-debug3") ){
@@ -838,35 +930,7 @@ Node SygusSplit::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >
   }
 }
 
-bool SygusSplit::isGenericRedundant( TypeNode tn, Node g ) {
-  //everything added to this cache should be mutually exclusive cases
-  std::map< Node, bool >::iterator it = d_gen_redundant[tn].find( g );
-  if( it==d_gen_redundant[tn].end() ){
-    Trace("sygus-gnf") << "Register generic for " << tn << " : " << g << std::endl;
-    Node gr = Rewriter::rewrite( g );
-    //replace variables in order left to right
-    std::map< TypeNode, int > var_count;
-    std::map< Node, Node > subs;
-    gr = getSygusNormalized( gr, var_count, subs );
-    Trace("sygus-gnf-debug") << "Generic " << g << " rewrites to " << gr << std::endl;
-    std::map< Node, Node >::iterator itg = d_gen_terms[tn].find( gr );
-    bool red = true;
-    if( itg==d_gen_terms[tn].end() ){
-      red = false;
-      d_gen_terms[tn][gr] = g;
-      Trace("sygus-gnf-debug") << "...not redundant." << std::endl;
-    }else{
-      Trace("sygus-gnf-debug") << "...redundant." << std::endl;
-      Trace("sygus-nf") << "* Sygus normal form : simplify since " << g << " and " << itg->second << " both rewrite to " << gr << std::endl;
-    }
-    d_gen_redundant[tn][g] = red;
-    return red;
-  }else{
-    return it->second;
-  }
-}
-
-Node SygusSplit::getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ) {
+Node SygusUtil::getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ) {
   return n;
   if( n.getKind()==SKOLEM ){
     std::map< Node, Node >::iterator its = subs.find( n );
index c638e5da699276e7d1f6d1724d36ae93686da0fe..69806e07677df6d1d5c3ac3ca3e49365424a82e7 100644 (file)
@@ -14,6 +14,8 @@
  ** Theory of datatypes.
  **/
 
+#include "cvc4_private.h"
+
 #ifndef __CVC4__THEORY__DATATYPES__DATATYPES_SYGUS_H
 #define __CVC4__THEORY__DATATYPES__DATATYPES_SYGUS_H
 
 #include "util/datatype.h"
 #include <iostream>
 #include <map>
+#include "context/context.h"
 #include "context/cdchunk_list.h"
+#include "context/cdhashmap.h"
+#include "context/cdo.h"
 
 namespace CVC4 {
-namespace theory { 
+namespace theory {
 namespace datatypes {
-  
+
+class SygusUtil;
+
 class SygusSplit
 {
 private:
+  SygusUtil * d_util;
   std::map< Node, std::vector< Node > > d_splits;
   std::map< TypeNode, std::vector< bool > > d_sygus_nred;
   std::map< TypeNode, std::map< int, std::map< int, std::vector< bool > > > > d_sygus_pc_nred;
@@ -44,9 +52,6 @@ private:
   std::map< TypeNode, std::map< int, Node > > d_arg_const;
   std::map< TypeNode, std::map< Node, int > > d_consts;
   std::map< TypeNode, std::map< Node, int > > d_ops;
-  //
-  std::map< TypeNode, std::vector< Node > > d_fv;
-  std::map< Node, TypeNode > d_fv_stype;
   // type to (rewritten) to original
   std::map< TypeNode, std::map< Node, Node > > d_gen_terms;
   std::map< TypeNode, std::map< Node, bool > > d_gen_redundant;
@@ -92,17 +97,62 @@ private:
   /** get arg type */
   TypeNode getArgType( const DatatypeConstructor& c, int i );
 private:
-  Node getGeneric( Node n, std::vector< int >& csIndices, std::vector< int >& sIndices, TypeNode& tng );
-  Node getGeneric2( const Datatype& dt, std::map< TypeNode, int >& var_count, std::vector< int >& csIndices, std::vector< int >& sIndices, unsigned index );
-  Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre );
+  // generic cache
   bool isGenericRedundant( TypeNode tn, Node g );
-  Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs );
 public:
+  SygusSplit( SygusUtil * util ) : d_util( util ) {}
   /** get sygus splits */
   void getSygusSplits( Node n, const Datatype& dt, std::vector< Node >& splits, std::vector< Node >& lemmas );
 };
 
 
+
+
+class SygusSymBreak
+{
+  typedef context::CDHashMap< Node, Node, NodeHashFunction > NodeMap;
+  typedef context::CDHashMap< Node, int, NodeHashFunction > IntMap;
+  typedef context::CDHashMap< int, int > IntIntMap;
+private:
+  SygusUtil * d_util;
+  NodeMap d_testers;
+  IntMap d_watched_terms;
+  IntIntMap d_watched_count;
+  context::CDO<Node> d_anchor;
+  context::CDO<int> d_prog_depth;
+  std::map< Node, Node > d_normalized;
+  std::map< Node, Node > d_normalized_to_orig;
+  void assignTester( Node tst, int depth );
+  Node getCandidateProgramAtDepth( int depth, Node prog, int curr_depth, std::map< TypeNode, int >& var_count, std::vector< Node >& testers );
+  void processProgramDepth( int depth );
+  context::CDO<Node> d_conflict;
+public:
+  SygusSymBreak( SygusUtil * util, context::Context* c );
+  /** add tester */
+  void addTester( Node tst );
+};
+
+class SygusUtil
+{
+  friend class SygusSplit;
+  friend class SygusSymBreak;
+private:
+  std::map< TypeNode, std::vector< Node > > d_fv;
+  std::map< Node, TypeNode > d_fv_stype;
+  SygusSplit * d_split;
+  SygusSymBreak * d_sym_break;
+private:
+  Node getVar( TypeNode tn, int i );
+  Node getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count );
+  Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre );
+  Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs );
+public:
+  SygusUtil( context::Context* c );
+  SygusSplit * getSplit() { return d_split; }
+  SygusSymBreak * getSymBreak() { return d_sym_break; }
+};
+
+
 }
 }
 }
index f9a451459b7651e5f037a3b85486bcd96dcc2f31..f38545817baf0917030041e44106ed8d9849edd3 100644 (file)
@@ -68,9 +68,9 @@ TheoryDatatypes::TheoryDatatypes(Context* c, UserContext* u, OutputChannel& out,
   d_dtfCounter = 0;
 
   if( options::ceGuidedInst() ){
-    d_sygus_split = new SygusSplit;
+    d_sygus_util = new SygusUtil( c );
   }else{
-    d_sygus_split = NULL;
+    d_sygus_util = NULL;
   }
 }
 
@@ -246,9 +246,9 @@ void TheoryDatatypes::check(Effort e) {
                 }else{
                   Trace("dt-split") << "*************Split for constructors on " << n <<  endl;
                   std::vector< Node > children;
-                  if( dt.isSygus() && d_sygus_split ){
+                  if( dt.isSygus() && d_sygus_util ){
                     std::vector< Node > lemmas;
-                    d_sygus_split->getSygusSplits( n, dt, children, lemmas );
+                    d_sygus_util->getSplit()->getSygusSplits( n, dt, children, lemmas );
                     for( unsigned i=0; i<lemmas.size(); i++ ){
                       Trace("dt-lemma-sygus") << "Dt sygus lemma : " << lemmas[i] << std::endl;
                       d_out->lemma( lemmas[i] );
@@ -362,6 +362,9 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){
   if( atom.getKind()==kind::APPLY_TESTER ){
     if( polarity ){
       Trace("dt-tester") << "Assert tester : " << atom << std::endl;
+      if( d_sygus_util ){
+        d_sygus_util->getSymBreak()->addTester( atom );
+      }
     }
     Node rep = getRepresentative( atom[0] );
     EqcInfo* eqc = getOrMakeEqcInfo( rep, true );
index e5b9d32bbfab0f6d52005cc29a99641274b51f4f..8943688fba4c1b2b00f0c542c175e365f20f7ece 100644 (file)
@@ -22,7 +22,6 @@
 #include "theory/theory.h"
 #include "util/datatype.h"
 #include "util/hash.h"
-#include "util/trans_closure.h"
 #include "theory/uf/equality_engine.h"
 
 #include <ext/hash_set>
@@ -34,7 +33,7 @@ namespace CVC4 {
 namespace theory {
 namespace datatypes {
 
-class SygusSplit;
+class SygusUtil;
 
 class TheoryDatatypes : public Theory {
 private:
@@ -179,8 +178,8 @@ private:
   unsigned d_dtfCounter;
   /** expand definition skolem functions */
   std::map< Node, Node > d_exp_def_skolem;
-  /** sygus split utility */
-  SygusSplit * d_sygus_split;
+  /** sygus utility */
+  SygusUtil * d_sygus_util;
 private:
   /** assert fact */
   void assertFact( Node fact, Node exp );