Option to use sampling for CEGIS (#1555)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 3 Feb 2018 03:04:49 +0000 (21:04 -0600)
committerGitHub <noreply@github.com>
Sat, 3 Feb 2018 03:04:49 +0000 (21:04 -0600)
src/options/options_handler.cpp
src/options/options_handler.h
src/options/quantifiers_modes.h
src/options/quantifiers_options
src/theory/datatypes/datatypes_sygus.cpp
src/theory/quantifiers/ce_guided_conjecture.cpp
src/theory/quantifiers/ce_guided_conjecture.h
src/theory/quantifiers/ce_guided_instantiation.cpp
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index c29cfc4d2ee0d2ca879b1893f59575ac44063a3c..61f7646eed28c617a60e75e42f087f9d3d2de32f 100644 (file)
@@ -492,6 +492,25 @@ all \n\
 \n\
 ";
 
+const std::string OptionsHandler::s_cegisSampleHelp =
+    "\
+Modes for sampling with counterexample-guided inductive synthesis (CEGIS),\
+supported by --cegis-sample:\n\
+\n\
+none (default) \n\
++ Do not use sampling with CEGIS.\n\
+\n\
+use \n\
++ Use sampling to accelerate CEGIS. This will rule out solutions for a\
+  conjecture when they are not satisfied by a sample point.\n\
+\n\
+trust  \n\
++ Trust that when a solution for a conjecture is always true under sampling,\
+  then it is indeed a solution. Note this option may print out spurious\
+  solutions for synthesis conjectures.\n\
+\n\
+";
+
 const std::string OptionsHandler::s_sygusInvTemplHelp = "\
 Template modes for sygus invariant synthesis, supported by --sygus-inv-templ:\n\
 \n\
@@ -877,6 +896,34 @@ OptionsHandler::stringToCegqiSingleInvMode(std::string option,
   }
 }
 
+theory::quantifiers::CegisSampleMode OptionsHandler::stringToCegisSampleMode(
+    std::string option, std::string optarg)
+{
+  if (optarg == "none")
+  {
+    return theory::quantifiers::CEGIS_SAMPLE_NONE;
+  }
+  else if (optarg == "use")
+  {
+    return theory::quantifiers::CEGIS_SAMPLE_USE;
+  }
+  else if (optarg == "trust")
+  {
+    return theory::quantifiers::CEGIS_SAMPLE_TRUST;
+  }
+  else if (optarg == "help")
+  {
+    puts(s_cegisSampleHelp.c_str());
+    exit(1);
+  }
+  else
+  {
+    throw OptionException(std::string("unknown option for --cegis-sample: `")
+                          + optarg
+                          + "'.  Try --cegis-sample help.");
+  }
+}
+
 theory::quantifiers::SygusInvTemplMode
 OptionsHandler::stringToSygusInvTemplMode(std::string option,
                                           std::string optarg)
index e7bd87ebdca817c705d679ec40007dd39261d26e..304009a98f8e7f06764023fa05943fb0d4f1ebc3 100644 (file)
@@ -108,6 +108,8 @@ public:
       std::string option, std::string optarg);
   theory::quantifiers::CegqiSingleInvMode stringToCegqiSingleInvMode(
       std::string option, std::string optarg);
+  theory::quantifiers::CegisSampleMode stringToCegisSampleMode(
+      std::string option, std::string optarg);
   theory::quantifiers::SygusInvTemplMode stringToSygusInvTemplMode(
       std::string option, std::string optarg);
   theory::quantifiers::MacrosQuantMode stringToMacrosQuantMode(
@@ -243,6 +245,7 @@ public:
   static const std::string s_sygusSolutionOutModeHelp;
   static const std::string s_cbqiBvIneqModeHelp;
   static const std::string s_cegqiSingleInvHelp;
+  static const std::string s_cegisSampleHelp;
   static const std::string s_sygusInvTemplHelp;
   static const std::string s_termDbModeHelp;
   static const std::string s_theoryOfModeHelp;
index 6274269ce8b0933378fc79e902d84ca4cc9a97f0..91fab54ff9709dba13138ccc0383170368e1eb41 100644 (file)
@@ -216,6 +216,16 @@ enum CegqiSingleInvMode {
   CEGQI_SI_MODE_ALL,
 };
 
+enum CegisSampleMode
+{
+  /** do not use samples for CEGIS */
+  CEGIS_SAMPLE_NONE,
+  /** use samples for CEGIS */
+  CEGIS_SAMPLE_USE,
+  /** trust samples for CEGQI */
+  CEGIS_SAMPLE_TRUST,
+};
+
 enum SygusInvTemplMode {
   /** synthesize I( x ) */
   SYGUS_INV_TEMPL_MODE_NONE,
index 96d73feebd03003da5727a91730726f5741315c3..34af81033f4e32095a905077613f6fc083dc9159 100644 (file)
@@ -297,6 +297,9 @@ option sygusCRefEvalMinExp --sygus-cref-eval-min-exp bool :default true
 
 option sygusStream --sygus-stream bool :read-write :default false
   enumerate a stream of solutions instead of terminating after the first one
+  
+option cegisSample --cegis-sample=MODE CVC4::theory::quantifiers::CegisSampleMode :read-write :default CVC4::theory::quantifiers::CEGIS_SAMPLE_NONE :include "options/quantifiers_modes.h" :handler stringToCegisSampleMode
+  mode for using samples in the counterexample-guided inductive synthesis loop
 
 # internal uses of sygus
 option sygusRewSynth --sygus-rr-synth bool :default false
@@ -323,6 +326,10 @@ option cbqiMultiInst --cbqi-multi-inst bool :read-write :default false
  when applicable, do multi instantiations per quantifier per round in counterexample-based quantifier instantiation
 option cbqiRepeatLit --cbqi-repeat-lit bool :read-write :default false
  solve literals more than once in counterexample-based quantifier instantiation
+option cbqiInnermost --cbqi-innermost bool :read-write :default true
+ only process innermost quantified formulas in counterexample-based quantifier instantiation
+option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false
+ process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation
  
 # CEGQI for arithmetic
 option cbqiUseInfInt --cbqi-use-inf-int bool :read-write :default false
@@ -341,10 +348,6 @@ option cbqiNopt --cbqi-nopt bool :default true
   non-optimal bounds for counterexample-based quantifier instantiation
 option cbqiLitDepend --cbqi-lit-dep bool :default true
   dependency lemmas for quantifier alternation in counterexample-based quantifier instantiation
-option cbqiInnermost --cbqi-innermost bool :read-write :default true
- only process innermost quantified formulas in counterexample-based quantifier instantiation
-option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false
- process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation
  
 # CEGQI for EPR
 option quantEpr --quant-epr bool :default false :read-write
index 7c3ab71d8cf11395b55ed16454ee3b141010d134..0f204383a04d9b7dd56f16d6b75fdd9b395bc00c 100644 (file)
@@ -816,7 +816,7 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d,
             d_sampler.find(a);
         if (its == d_sampler.end())
         {
-          d_sampler[a].initialize(d_tds, a, options::sygusSamples());
+          d_sampler[a].initializeSygus(d_tds, a, options::sygusSamples());
           its = d_sampler.find(a);
         }
         Node sample_ret = its->second.registerTerm(bv);
index cc00599d31da65d996973fcda3f3d4e5484c60c6..889a808791972f8e4b416229c9715c292799fe0e 100644 (file)
@@ -112,6 +112,15 @@ void CegConjecture::assign( Node q ) {
   d_base_inst = Rewriter::rewrite(d_qe->getInstantiate()->getInstantiation(
       d_embed_quant, vars, d_candidates));
   Trace("cegqi") << "Base instantiation is :      " << d_base_inst << std::endl;
+  d_base_body = d_base_inst;
+  if (d_base_body.getKind() == NOT && d_base_body[0].getKind() == FORALL)
+  {
+    for (const Node& v : d_base_body[0][0])
+    {
+      d_base_vars.push_back(v);
+    }
+    d_base_body = d_base_body[0][1];
+  }
 
   // register this term with sygus database and other utilities that impact
   // the enumerative sygus search
@@ -182,7 +191,16 @@ void CegConjecture::assign( Node q ) {
     Trace("cegqi-lemma") << "Cegqi::Lemma : initial (guarded) lemma : " << lem << std::endl;
     d_qe->getOutputChannel().lemma( lem );
   }
-  
+
+  // assign the cegis sampler if applicable
+  if (options::cegisSample() != CEGIS_SAMPLE_NONE)
+  {
+    Trace("cegis-sample") << "Initialize sampler for " << d_base_body << "..."
+                          << std::endl;
+    TypeNode bt = d_base_body.getType();
+    d_cegis_sampler.initialize(bt, d_base_vars, options::sygusSamples());
+  }
+
   Trace("cegqi") << "...finished, single invocation = " << isSingleInvocation() << std::endl;
 }
 
@@ -284,6 +302,18 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode
   //check whether we will run CEGIS on inner skolem variables
   bool sk_refine = ( !isGround() || d_refine_count==0 ) && ( !d_ceg_pbe->isPbe() || constructed_cand );
   if( sk_refine ){
+    if (options::cegisSample() == CEGIS_SAMPLE_TRUST)
+    {
+      // we have that the current candidate passed a sample test
+      // since we trust sampling in this mode, we assert there is no
+      // counterexample to the conjecture here.
+      NodeManager* nm = NodeManager::currentNM();
+      Node lem = nm->mkNode(OR, d_quant.negate(), nm->mkConst(false));
+      lem = getStreamGuardedLemma(lem);
+      lems.push_back(lem);
+      recordInstantiation(c_model_values);
+      return;
+    }
     Assert( d_ce_sk.empty() );
     d_ce_sk.push_back( std::vector< Node >() );
   }else{
@@ -329,12 +359,7 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode
       std::map< Node, Node > visited_n;
       lem = d_qe->getTermDatabaseSygus()->getEagerUnfold( lem, visited_n );
     }
-    if( options::sygusStream() ){
-      // if we are in streaming mode, we guard with the current stream guard
-      Node curr_stream_guard = getCurrentStreamGuard();
-      Assert( !curr_stream_guard.isNull() );
-      lem = NodeManager::currentNM()->mkNode( kind::OR, curr_stream_guard.negate(), lem );
-    }
+    lem = getStreamGuardedLemma(lem);
     lems.push_back( lem );
     recordInstantiation( c_model_values );
   }
@@ -404,17 +429,13 @@ void CegConjecture::doRefine( std::vector< Node >& lems ){
   
   Trace("cegqi-refine") << "doRefine : construct and finalize lemmas..." << std::endl;
   
-  Node lem = base_lem;
   
   base_lem = base_lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() );
   base_lem = Rewriter::rewrite( base_lem );
-  d_refinement_lemmas_base.push_back( base_lem );
-  
-  lem = NodeManager::currentNM()->mkNode( OR, getGuard().negate(), lem );
-  
-  lem = lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() );
-  lem = Rewriter::rewrite( lem );
-  d_refinement_lemmas.push_back( lem );
+  d_refinement_lemmas.push_back(base_lem);
+
+  Node lem =
+      NodeManager::currentNM()->mkNode(OR, getGuard().negate(), base_lem);
   lems.push_back( lem );
 
   d_ce_sk.clear();
@@ -473,6 +494,18 @@ Node CegConjecture::getCurrentStreamGuard() const {
   }
 }
 
+Node CegConjecture::getStreamGuardedLemma(Node n) const
+{
+  if (options::sygusStream())
+  {
+    // if we are in streaming mode, we guard with the current stream guard
+    Node csg = getCurrentStreamGuard();
+    Assert(!csg.isNull());
+    return NodeManager::currentNM()->mkNode(kind::OR, csg.negate(), n);
+  }
+  return n;
+}
+
 Node CegConjecture::getNextDecisionRequest( unsigned& priority ) {
   // first, must try the guard
   // which denotes "this conjecture is feasible"
@@ -596,7 +629,8 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
         std::map<Node, SygusSampler>::iterator its = d_sampler.find(prog);
         if (its == d_sampler.end())
         {
-          d_sampler[prog].initialize(sygusDb, prog, options::sygusSamples());
+          d_sampler[prog].initializeSygus(
+              sygusDb, prog, options::sygusSamples());
           its = d_sampler.find(prog);
         }
         Node solb = sygusDb->sygusToBuiltin(sol, prog.getType());
@@ -793,6 +827,83 @@ Node CegConjecture::getSymmetryBreakingPredicate(
   }
 }
 
+bool CegConjecture::sampleAddRefinementLemma(std::vector<Node>& vals,
+                                             std::vector<Node>& lems)
+{
+  if (Trace.isOn("cegis-sample"))
+  {
+    Trace("cegis-sample") << "Check sampling for candidate solution"
+                          << std::endl;
+    for (unsigned i = 0, size = vals.size(); i < size; i++)
+    {
+      Trace("cegis-sample")
+          << "  " << d_candidates[i] << " -> " << vals[i] << std::endl;
+    }
+  }
+  Assert(vals.size() == d_candidates.size());
+  Node sbody = d_base_body.substitute(
+      d_candidates.begin(), d_candidates.end(), vals.begin(), vals.end());
+  Trace("cegis-sample-debug") << "Sample " << sbody << std::endl;
+  // do eager unfolding
+  std::map<Node, Node> visited_n;
+  sbody = d_qe->getTermDatabaseSygus()->getEagerUnfold(sbody, visited_n);
+  Trace("cegis-sample") << "Sample (after unfolding): " << sbody << std::endl;
+
+  NodeManager* nm = NodeManager::currentNM();
+  for (unsigned i = 0, size = d_cegis_sampler.getNumSamplePoints(); i < size;
+       i++)
+  {
+    if (d_cegis_sample_refine.find(i) == d_cegis_sample_refine.end())
+    {
+      Node ev = d_cegis_sampler.evaluate(sbody, i);
+      Trace("cegis-sample-debug")
+          << "...evaluate point #" << i << " to " << ev << std::endl;
+      Assert(ev.isConst());
+      Assert(ev.getType().isBoolean());
+      if (!ev.getConst<bool>())
+      {
+        Trace("cegis-sample-debug") << "...false for point #" << i << std::endl;
+        // mark this as a CEGIS point (no longer sampled)
+        d_cegis_sample_refine.insert(i);
+        std::vector<Node> pt;
+        d_cegis_sampler.getSamplePoint(i, pt);
+        Assert(d_base_vars.size() == pt.size());
+        Node rlem = d_base_body.substitute(
+            d_base_vars.begin(), d_base_vars.end(), pt.begin(), pt.end());
+        rlem = Rewriter::rewrite(rlem);
+        if (std::find(
+                d_refinement_lemmas.begin(), d_refinement_lemmas.end(), rlem)
+            == d_refinement_lemmas.end())
+        {
+          if (Trace.isOn("cegis-sample"))
+          {
+            Trace("cegis-sample") << "   false for point #" << i << " : ";
+            for (const Node& cn : pt)
+            {
+              Trace("cegis-sample") << cn << " ";
+            }
+            Trace("cegis-sample") << std::endl;
+          }
+          Trace("cegqi-engine") << "  *** Refine by sampling" << std::endl;
+          d_refinement_lemmas.push_back(rlem);
+          // if trust, we are not interested in sending out refinement lemmas
+          if (options::cegisSample() != CEGIS_SAMPLE_TRUST)
+          {
+            Node lem = nm->mkNode(OR, getGuard().negate(), rlem);
+            lems.push_back(lem);
+          }
+          return true;
+        }
+        else
+        {
+          Trace("cegis-sample-debug") << "...duplicate." << std::endl;
+        }
+      }
+    }
+  }
+  return false;
+}
+
 }/* namespace CVC4::theory::quantifiers */
 }/* namespace CVC4::theory */
 }/* namespace CVC4 */
index 011967ca1c22ae746e213250777e386b55965439..dae261111cae91689e6558dd171f53547c40c6c1 100644 (file)
@@ -75,9 +75,6 @@ public:
   * This is step 2(b) of Figure 3 of Reynolds et al CAV 2015.
   */
   void doRefine(std::vector< Node >& lems);
-  /** Print the synthesis solution
-   * singleInvocation is whether the solution was found by single invocation techniques.
-   */
   //-------------------------------end for counterexample-guided check/refine
   /**
    * prints the synthesis solution to output stream out.
@@ -124,10 +121,21 @@ public:
   //-----------------------------------refinement lemmas
   /** get number of refinement lemmas we have added so far */
   unsigned getNumRefinementLemmas() { return d_refinement_lemmas.size(); }
-  /** get refinement lemma */
+  /** get refinement lemma
+   *
+   * If d_embed_quant is forall d. exists y. P( d, y ), then a refinement
+   * lemma is one of the form ~P( d_candidates, c ) for some c.
+   */
   Node getRefinementLemma( unsigned i ) { return d_refinement_lemmas[i]; }
-  /** get refinement lemma */
-  Node getRefinementBaseLemma( unsigned i ) { return d_refinement_lemmas_base[i]; }
+  /** sample add refinement lemma
+   *
+   * This function will check if there is a sample point in d_sampler that
+   * refutes the candidate solution (d_quant_vars->vals). If so, it adds a
+   * refinement lemma to the lists d_refinement_lemmas that corresponds to that
+   * sample point, and adds a lemma to lems if cegisSample mode is not trust.
+   */
+  bool sampleAddRefinementLemma(std::vector<Node>& vals,
+                                std::vector<Node>& lems);
   //-----------------------------------end refinement lemmas
 
   /** get program by examples utility */
@@ -151,14 +159,21 @@ private:
   /** grammar utility */
   std::unique_ptr<CegGrammarConstructor> d_ceg_gc;
   /** list of constants for quantified formula
-  * The Skolems for the negation of d_embed_quant.
+  * The outer Skolems for the negation of d_embed_quant.
   */
   std::vector< Node > d_candidates;
   /** base instantiation
   * If d_embed_quant is forall d. exists y. P( d, y ), then
-  * this is the formula  P( candidates, y ).
+  * this is the formula  exists y. P( d_candidates, y ).
   */
   Node d_base_inst;
+  /** If d_base_inst is exists y. P( d, y ), then this is y. */
+  std::vector<Node> d_base_vars;
+  /**
+   * If d_base_inst is exists y. P( d, y ), then this is the formula
+   * P( d_candidates, y ).
+   */
+  Node d_base_body;
   /** expand base inst to disjuncts */
   std::vector< Node > d_base_disj;
   /** list of variables on inner quantification */
@@ -170,14 +185,13 @@ private:
   //-----------------------------------refinement lemmas
   /** refinement lemmas */
   std::vector< Node > d_refinement_lemmas;
-  std::vector< Node > d_refinement_lemmas_base;
   //-----------------------------------end refinement lemmas
 
-  /** quantified formula asserted */
+  /** the asserted (negated) conjecture */
   Node d_quant;
-  /** quantified formula (after simplification) */
+  /** (negated) conjecture after simplification */
   Node d_simp_quant;
-  /** quantified formula (after simplification, conversion to deep embedding) */
+  /** (negated) conjecture after simplification, conversion to deep embedding */
   Node d_embed_quant;
   /** candidate information */
   class CandidateInfo {
@@ -227,6 +241,12 @@ private:
   std::vector< Node > d_stream_guards;
   /** get current stream guard */
   Node getCurrentStreamGuard() const;
+  /** get stream guarded lemma
+   *
+   * If sygusStream is enabled, this returns ( G V n ) where G is the guard
+   * returned by getCurrentStreamGuard, otherwise this returns n.
+   */
+  Node getStreamGuardedLemma(Node n) const;
   //-------------------------------- end sygus stream
   //-------------------------------- non-syntax guided (deprecated)
   /** Whether we are syntax-guided (e.g. was the input in SyGuS format).
@@ -242,6 +262,18 @@ private:
    * rewrite rules.
    */
   std::map<Node, SygusSampler> d_sampler;
+  /** sampler object for the option cegisSample()
+   *
+   * This samples points of the type of the inner variables of the synthesis
+   * conjecture (d_base_vars).
+   */
+  SygusSampler d_cegis_sampler;
+  /** cegis sample refine points
+   *
+   * Stores the list of indices of sample points in d_cegis_sampler we have
+   * added as refinement lemmas.
+   */
+  std::unordered_set<unsigned> d_cegis_sample_refine;
 };
 
 } /* namespace CVC4::theory::quantifiers */
index dc359d252676a2dd7939ce55f674cb56edc3d318..38cfb9ba701439da7ccdd7bfa8afebc32141d8ec 100644 (file)
@@ -238,17 +238,33 @@ void CegInstantiation::checkCegConjecture( CegConjecture * conj ) {
 
 void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vector< Node >& vs, std::vector< Node >& ms, std::vector< Node >& lems ) {
   Trace("sygus-cref-eval") << "Cref eval : conjecture has " << conj->getNumRefinementLemmas() << " refinement lemmas." << std::endl;
-  if( conj->getNumRefinementLemmas()>0 ){
+  unsigned nlemmas = conj->getNumRefinementLemmas();
+  if (nlemmas > 0 || options::cegisSample() != CEGIS_SAMPLE_NONE)
+  {
     Assert( vs.size()==ms.size() );
 
     TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus();
     Node nfalse = d_quantEngine->getTermUtil()->d_false;
     Node neg_guard = conj->getGuard().negate();
-    for( unsigned i=0; i<conj->getNumRefinementLemmas(); i++ ){
+    for (unsigned i = 0; i <= nlemmas; i++)
+    {
+      if (i == nlemmas)
+      {
+        bool addedSample = false;
+        // find a new one by sampling, if applicable
+        if (options::cegisSample() != CEGIS_SAMPLE_NONE)
+        {
+          addedSample = conj->sampleAddRefinementLemma(ms, lems);
+        }
+        if (!addedSample)
+        {
+          return;
+        }
+      }
       Node lem;
       std::map< Node, Node > visited;
       std::map< Node, std::vector< Node > > exp;
-      lem = conj->getRefinementBaseLemma( i );
+      lem = conj->getRefinementLemma(i);
       if( !lem.isNull() ){
         std::vector< Node > lem_conj;
         //break into conjunctions
index b5e63a6abd3494c55b2b43fbb93fb113f7ee7b7e..0b8f390f3c69fea2493d58dd523efc5905ff25e3 100644 (file)
@@ -65,7 +65,25 @@ Node LazyTrie::add(Node n,
 
 SygusSampler::SygusSampler() : d_tds(nullptr), d_is_valid(false) {}
 
-void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples)
+void SygusSampler::initialize(TypeNode tn,
+                              std::vector<Node>& vars,
+                              unsigned nsamples)
+{
+  d_tds = nullptr;
+  d_is_valid = true;
+  d_tn = tn;
+  d_ftn = TypeNode::null();
+  d_vars.insert(d_vars.end(), vars.begin(), vars.end());
+  for (const Node& sv : vars)
+  {
+    TypeNode svt = sv.getType();
+    d_var_index[sv] = d_type_vars[svt].size();
+    d_type_vars[svt].push_back(sv);
+  }
+  initializeSamples(nsamples);
+}
+
+void SygusSampler::initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples)
 {
   d_tds = tds;
   d_is_valid = true;
@@ -73,12 +91,12 @@ void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples)
   Assert(d_ftn.isDatatype());
   const Datatype& dt = static_cast<DatatypeType>(d_ftn.toType()).getDatatype();
   Assert(dt.isSygus());
+  d_tn = TypeNode::fromType(dt.getSygusType());
 
   Trace("sygus-sample") << "Register sampler for " << f << std::endl;
 
   d_var_index.clear();
   d_type_vars.clear();
-  std::vector<TypeNode> types;
   // get the sygus variable list
   Node var_list = Node::fromExpr(dt.getSygusVarList());
   if (!var_list.isNull())
@@ -87,14 +105,24 @@ void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples)
     {
       TypeNode svt = sv.getType();
       d_var_index[sv] = d_type_vars[svt].size();
+      d_vars.push_back(sv);
       d_type_vars[svt].push_back(sv);
-      types.push_back(svt);
-      Trace("sygus-sample") << "  var #" << types.size() << " : " << sv << " : "
-                            << svt << std::endl;
     }
   }
+  initializeSamples(nsamples);
+}
 
+void SygusSampler::initializeSamples(unsigned nsamples)
+{
   d_samples.clear();
+  std::vector<TypeNode> types;
+  for (const Node& v : d_vars)
+  {
+    TypeNode vt = v.getType();
+    types.push_back(vt);
+    Trace("sygus-sample") << "  var #" << types.size() << " : " << v << " : "
+                          << vt << std::endl;
+  }
   for (unsigned i = 0; i < nsamples; i++)
   {
     std::vector<Node> sample_pt;
@@ -121,6 +149,7 @@ Node SygusSampler::registerTerm(Node n, bool forceKeep)
 {
   if (d_is_valid)
   {
+    Assert(n.getType() == d_tn);
     return d_trie.add(n, this, 0, d_samples.size(), forceKeep);
   }
   return n;
@@ -254,10 +283,20 @@ bool SygusSampler::containsFreeVariables(Node a, Node b)
   return true;
 }
 
+void SygusSampler::getSamplePoint(unsigned index, std::vector<Node>& pt)
+{
+  Assert(index < d_samples.size());
+  std::vector<Node>& spt = d_samples[index];
+  pt.insert(pt.end(), spt.begin(), spt.end());
+}
+
 Node SygusSampler::evaluate(Node n, unsigned index)
 {
   Assert(index < d_samples.size());
-  Node ev = d_tds->evaluateBuiltin(d_ftn, n, d_samples[index]);
+  // just a substitution
+  std::vector<Node>& pt = d_samples[index];
+  Node ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end());
+  ev = Rewriter::rewrite(ev);
   Trace("sygus-sample-ev") << "( " << n << ", " << index << " ) -> " << ev
                            << std::endl;
   return ev;
index 8979316493f5c596bec0294ffcef8d5e7f262273..09f4124fe689bc5954931017d3e310ac4353710d 100644 (file)
@@ -137,12 +137,19 @@ class SygusSampler : public LazyTrieEvaluator
   virtual ~SygusSampler() {}
   /** initialize
    *
-   * tds : reference to a sygus database,
+   * tn : the return type of terms we will be testing with this class
+   * vars : the variables we are testing substitutions for
+   * nsamples : number of sample points this class will test.
+   */
+  void initialize(TypeNode tn, std::vector<Node>& vars, unsigned nsamples);
+  /** initialize sygus
+   *
+   * tds : pointer to sygus database,
    * f : a term of some SyGuS datatype type whose (builtin) values we will be
-   * testing,
+   * testing under the free variables in the grammar of f,
    * nsamples : number of sample points this class will test.
    */
-  void initialize(TermDbSygus* tds, Node f, unsigned nsamples);
+  void initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples);
   /** register term n with this sampler database
    *
    * forceKeep is whether we wish to force that n is chosen as a representative
@@ -172,6 +179,13 @@ class SygusSampler : public LazyTrieEvaluator
    * are those that occur in the range d_type_vars.
    */
   bool containsFreeVariables(Node a, Node b);
+  /** get number of sample points */
+  unsigned getNumSamplePoints() const { return d_samples.size(); }
+  /** get sample point
+   *
+   * Appends sample point #index to the vector pt.
+   */
+  void getSamplePoint(unsigned index, std::vector<Node>& pt);
   /** evaluate n on sample point index */
   Node evaluate(Node n, unsigned index);
 
@@ -181,7 +195,11 @@ class SygusSampler : public LazyTrieEvaluator
   /** samples */
   std::vector<std::vector<Node> > d_samples;
   /** type of nodes we will be registering with this class */
+  TypeNode d_tn;
+  /** the sygus type for this sampler (if applicable). */
   TypeNode d_ftn;
+  /** all variables */
+  std::vector<Node> d_vars;
   /** type variables
    *
    * For each type, a list of variables in the grammar we are considering, for
@@ -213,6 +231,11 @@ class SygusSampler : public LazyTrieEvaluator
    * store these in the vector fvs.
    */
   void computeFreeVariables(Node n, std::vector<Node>& fvs);
+  /** initialize samples
+   *
+   * Adds nsamples sample points to d_samples.
+   */
+  void initializeSamples(unsigned nsamples);
   /** get random value for a type
    *
    * Returns a random value for the given type based on the random number