From b72de87fb2804325137352ce79a6044d1b805576 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 2 Feb 2018 21:04:49 -0600 Subject: [PATCH] Option to use sampling for CEGIS (#1555) --- src/options/options_handler.cpp | 47 ++++++ src/options/options_handler.h | 3 + src/options/quantifiers_modes.h | 10 ++ src/options/quantifiers_options | 11 +- src/theory/datatypes/datatypes_sygus.cpp | 2 +- .../quantifiers/ce_guided_conjecture.cpp | 143 ++++++++++++++++-- src/theory/quantifiers/ce_guided_conjecture.h | 56 +++++-- .../quantifiers/ce_guided_instantiation.cpp | 22 ++- src/theory/quantifiers/sygus_sampler.cpp | 51 ++++++- src/theory/quantifiers/sygus_sampler.h | 29 +++- 10 files changed, 329 insertions(+), 45 deletions(-) diff --git a/src/options/options_handler.cpp b/src/options/options_handler.cpp index c29cfc4d2..61f7646ee 100644 --- a/src/options/options_handler.cpp +++ b/src/options/options_handler.cpp @@ -492,6 +492,25 @@ all \n\ \n\ "; +const std::string OptionsHandler::s_cegisSampleHelp = + "\ +Modes for sampling with counterexample-guided inductive synthesis (CEGIS),\ +supported by --cegis-sample:\n\ +\n\ +none (default) \n\ ++ Do not use sampling with CEGIS.\n\ +\n\ +use \n\ ++ Use sampling to accelerate CEGIS. This will rule out solutions for a\ + conjecture when they are not satisfied by a sample point.\n\ +\n\ +trust \n\ ++ Trust that when a solution for a conjecture is always true under sampling,\ + then it is indeed a solution. Note this option may print out spurious\ + solutions for synthesis conjectures.\n\ +\n\ +"; + const std::string OptionsHandler::s_sygusInvTemplHelp = "\ Template modes for sygus invariant synthesis, supported by --sygus-inv-templ:\n\ \n\ @@ -877,6 +896,34 @@ OptionsHandler::stringToCegqiSingleInvMode(std::string option, } } +theory::quantifiers::CegisSampleMode OptionsHandler::stringToCegisSampleMode( + std::string option, std::string optarg) +{ + if (optarg == "none") + { + return theory::quantifiers::CEGIS_SAMPLE_NONE; + } + else if (optarg == "use") + { + return theory::quantifiers::CEGIS_SAMPLE_USE; + } + else if (optarg == "trust") + { + return theory::quantifiers::CEGIS_SAMPLE_TRUST; + } + else if (optarg == "help") + { + puts(s_cegisSampleHelp.c_str()); + exit(1); + } + else + { + throw OptionException(std::string("unknown option for --cegis-sample: `") + + optarg + + "'. Try --cegis-sample help."); + } +} + theory::quantifiers::SygusInvTemplMode OptionsHandler::stringToSygusInvTemplMode(std::string option, std::string optarg) diff --git a/src/options/options_handler.h b/src/options/options_handler.h index e7bd87ebd..304009a98 100644 --- a/src/options/options_handler.h +++ b/src/options/options_handler.h @@ -108,6 +108,8 @@ public: std::string option, std::string optarg); theory::quantifiers::CegqiSingleInvMode stringToCegqiSingleInvMode( std::string option, std::string optarg); + theory::quantifiers::CegisSampleMode stringToCegisSampleMode( + std::string option, std::string optarg); theory::quantifiers::SygusInvTemplMode stringToSygusInvTemplMode( std::string option, std::string optarg); theory::quantifiers::MacrosQuantMode stringToMacrosQuantMode( @@ -243,6 +245,7 @@ public: static const std::string s_sygusSolutionOutModeHelp; static const std::string s_cbqiBvIneqModeHelp; static const std::string s_cegqiSingleInvHelp; + static const std::string s_cegisSampleHelp; static const std::string s_sygusInvTemplHelp; static const std::string s_termDbModeHelp; static const std::string s_theoryOfModeHelp; diff --git a/src/options/quantifiers_modes.h b/src/options/quantifiers_modes.h index 6274269ce..91fab54ff 100644 --- a/src/options/quantifiers_modes.h +++ b/src/options/quantifiers_modes.h @@ -216,6 +216,16 @@ enum CegqiSingleInvMode { CEGQI_SI_MODE_ALL, }; +enum CegisSampleMode +{ + /** do not use samples for CEGIS */ + CEGIS_SAMPLE_NONE, + /** use samples for CEGIS */ + CEGIS_SAMPLE_USE, + /** trust samples for CEGQI */ + CEGIS_SAMPLE_TRUST, +}; + enum SygusInvTemplMode { /** synthesize I( x ) */ SYGUS_INV_TEMPL_MODE_NONE, diff --git a/src/options/quantifiers_options b/src/options/quantifiers_options index 96d73feeb..34af81033 100644 --- a/src/options/quantifiers_options +++ b/src/options/quantifiers_options @@ -297,6 +297,9 @@ option sygusCRefEvalMinExp --sygus-cref-eval-min-exp bool :default true option sygusStream --sygus-stream bool :read-write :default false enumerate a stream of solutions instead of terminating after the first one + +option cegisSample --cegis-sample=MODE CVC4::theory::quantifiers::CegisSampleMode :read-write :default CVC4::theory::quantifiers::CEGIS_SAMPLE_NONE :include "options/quantifiers_modes.h" :handler stringToCegisSampleMode + mode for using samples in the counterexample-guided inductive synthesis loop # internal uses of sygus option sygusRewSynth --sygus-rr-synth bool :default false @@ -323,6 +326,10 @@ option cbqiMultiInst --cbqi-multi-inst bool :read-write :default false when applicable, do multi instantiations per quantifier per round in counterexample-based quantifier instantiation option cbqiRepeatLit --cbqi-repeat-lit bool :read-write :default false solve literals more than once in counterexample-based quantifier instantiation +option cbqiInnermost --cbqi-innermost bool :read-write :default true + only process innermost quantified formulas in counterexample-based quantifier instantiation +option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false + process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation # CEGQI for arithmetic option cbqiUseInfInt --cbqi-use-inf-int bool :read-write :default false @@ -341,10 +348,6 @@ option cbqiNopt --cbqi-nopt bool :default true non-optimal bounds for counterexample-based quantifier instantiation option cbqiLitDepend --cbqi-lit-dep bool :default true dependency lemmas for quantifier alternation in counterexample-based quantifier instantiation -option cbqiInnermost --cbqi-innermost bool :read-write :default true - only process innermost quantified formulas in counterexample-based quantifier instantiation -option cbqiNestedQE --cbqi-nested-qe bool :read-write :default false - process nested quantified formulas with quantifier elimination in counterexample-based quantifier instantiation # CEGQI for EPR option quantEpr --quant-epr bool :default false :read-write diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index 7c3ab71d8..0f204383a 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -816,7 +816,7 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, d_sampler.find(a); if (its == d_sampler.end()) { - d_sampler[a].initialize(d_tds, a, options::sygusSamples()); + d_sampler[a].initializeSygus(d_tds, a, options::sygusSamples()); its = d_sampler.find(a); } Node sample_ret = its->second.registerTerm(bv); diff --git a/src/theory/quantifiers/ce_guided_conjecture.cpp b/src/theory/quantifiers/ce_guided_conjecture.cpp index cc00599d3..889a80879 100644 --- a/src/theory/quantifiers/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/ce_guided_conjecture.cpp @@ -112,6 +112,15 @@ void CegConjecture::assign( Node q ) { d_base_inst = Rewriter::rewrite(d_qe->getInstantiate()->getInstantiation( d_embed_quant, vars, d_candidates)); Trace("cegqi") << "Base instantiation is : " << d_base_inst << std::endl; + d_base_body = d_base_inst; + if (d_base_body.getKind() == NOT && d_base_body[0].getKind() == FORALL) + { + for (const Node& v : d_base_body[0][0]) + { + d_base_vars.push_back(v); + } + d_base_body = d_base_body[0][1]; + } // register this term with sygus database and other utilities that impact // the enumerative sygus search @@ -182,7 +191,16 @@ void CegConjecture::assign( Node q ) { Trace("cegqi-lemma") << "Cegqi::Lemma : initial (guarded) lemma : " << lem << std::endl; d_qe->getOutputChannel().lemma( lem ); } - + + // assign the cegis sampler if applicable + if (options::cegisSample() != CEGIS_SAMPLE_NONE) + { + Trace("cegis-sample") << "Initialize sampler for " << d_base_body << "..." + << std::endl; + TypeNode bt = d_base_body.getType(); + d_cegis_sampler.initialize(bt, d_base_vars, options::sygusSamples()); + } + Trace("cegqi") << "...finished, single invocation = " << isSingleInvocation() << std::endl; } @@ -284,6 +302,18 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode //check whether we will run CEGIS on inner skolem variables bool sk_refine = ( !isGround() || d_refine_count==0 ) && ( !d_ceg_pbe->isPbe() || constructed_cand ); if( sk_refine ){ + if (options::cegisSample() == CEGIS_SAMPLE_TRUST) + { + // we have that the current candidate passed a sample test + // since we trust sampling in this mode, we assert there is no + // counterexample to the conjecture here. + NodeManager* nm = NodeManager::currentNM(); + Node lem = nm->mkNode(OR, d_quant.negate(), nm->mkConst(false)); + lem = getStreamGuardedLemma(lem); + lems.push_back(lem); + recordInstantiation(c_model_values); + return; + } Assert( d_ce_sk.empty() ); d_ce_sk.push_back( std::vector< Node >() ); }else{ @@ -329,12 +359,7 @@ void CegConjecture::doCheck(std::vector< Node >& lems, std::vector< Node >& mode std::map< Node, Node > visited_n; lem = d_qe->getTermDatabaseSygus()->getEagerUnfold( lem, visited_n ); } - if( options::sygusStream() ){ - // if we are in streaming mode, we guard with the current stream guard - Node curr_stream_guard = getCurrentStreamGuard(); - Assert( !curr_stream_guard.isNull() ); - lem = NodeManager::currentNM()->mkNode( kind::OR, curr_stream_guard.negate(), lem ); - } + lem = getStreamGuardedLemma(lem); lems.push_back( lem ); recordInstantiation( c_model_values ); } @@ -404,17 +429,13 @@ void CegConjecture::doRefine( std::vector< Node >& lems ){ Trace("cegqi-refine") << "doRefine : construct and finalize lemmas..." << std::endl; - Node lem = base_lem; base_lem = base_lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() ); base_lem = Rewriter::rewrite( base_lem ); - d_refinement_lemmas_base.push_back( base_lem ); - - lem = NodeManager::currentNM()->mkNode( OR, getGuard().negate(), lem ); - - lem = lem.substitute( sk_vars.begin(), sk_vars.end(), sk_subs.begin(), sk_subs.end() ); - lem = Rewriter::rewrite( lem ); - d_refinement_lemmas.push_back( lem ); + d_refinement_lemmas.push_back(base_lem); + + Node lem = + NodeManager::currentNM()->mkNode(OR, getGuard().negate(), base_lem); lems.push_back( lem ); d_ce_sk.clear(); @@ -473,6 +494,18 @@ Node CegConjecture::getCurrentStreamGuard() const { } } +Node CegConjecture::getStreamGuardedLemma(Node n) const +{ + if (options::sygusStream()) + { + // if we are in streaming mode, we guard with the current stream guard + Node csg = getCurrentStreamGuard(); + Assert(!csg.isNull()); + return NodeManager::currentNM()->mkNode(kind::OR, csg.negate(), n); + } + return n; +} + Node CegConjecture::getNextDecisionRequest( unsigned& priority ) { // first, must try the guard // which denotes "this conjecture is feasible" @@ -596,7 +629,8 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation std::map::iterator its = d_sampler.find(prog); if (its == d_sampler.end()) { - d_sampler[prog].initialize(sygusDb, prog, options::sygusSamples()); + d_sampler[prog].initializeSygus( + sygusDb, prog, options::sygusSamples()); its = d_sampler.find(prog); } Node solb = sygusDb->sygusToBuiltin(sol, prog.getType()); @@ -793,6 +827,83 @@ Node CegConjecture::getSymmetryBreakingPredicate( } } +bool CegConjecture::sampleAddRefinementLemma(std::vector& vals, + std::vector& lems) +{ + if (Trace.isOn("cegis-sample")) + { + Trace("cegis-sample") << "Check sampling for candidate solution" + << std::endl; + for (unsigned i = 0, size = vals.size(); i < size; i++) + { + Trace("cegis-sample") + << " " << d_candidates[i] << " -> " << vals[i] << std::endl; + } + } + Assert(vals.size() == d_candidates.size()); + Node sbody = d_base_body.substitute( + d_candidates.begin(), d_candidates.end(), vals.begin(), vals.end()); + Trace("cegis-sample-debug") << "Sample " << sbody << std::endl; + // do eager unfolding + std::map visited_n; + sbody = d_qe->getTermDatabaseSygus()->getEagerUnfold(sbody, visited_n); + Trace("cegis-sample") << "Sample (after unfolding): " << sbody << std::endl; + + NodeManager* nm = NodeManager::currentNM(); + for (unsigned i = 0, size = d_cegis_sampler.getNumSamplePoints(); i < size; + i++) + { + if (d_cegis_sample_refine.find(i) == d_cegis_sample_refine.end()) + { + Node ev = d_cegis_sampler.evaluate(sbody, i); + Trace("cegis-sample-debug") + << "...evaluate point #" << i << " to " << ev << std::endl; + Assert(ev.isConst()); + Assert(ev.getType().isBoolean()); + if (!ev.getConst()) + { + Trace("cegis-sample-debug") << "...false for point #" << i << std::endl; + // mark this as a CEGIS point (no longer sampled) + d_cegis_sample_refine.insert(i); + std::vector pt; + d_cegis_sampler.getSamplePoint(i, pt); + Assert(d_base_vars.size() == pt.size()); + Node rlem = d_base_body.substitute( + d_base_vars.begin(), d_base_vars.end(), pt.begin(), pt.end()); + rlem = Rewriter::rewrite(rlem); + if (std::find( + d_refinement_lemmas.begin(), d_refinement_lemmas.end(), rlem) + == d_refinement_lemmas.end()) + { + if (Trace.isOn("cegis-sample")) + { + Trace("cegis-sample") << " false for point #" << i << " : "; + for (const Node& cn : pt) + { + Trace("cegis-sample") << cn << " "; + } + Trace("cegis-sample") << std::endl; + } + Trace("cegqi-engine") << " *** Refine by sampling" << std::endl; + d_refinement_lemmas.push_back(rlem); + // if trust, we are not interested in sending out refinement lemmas + if (options::cegisSample() != CEGIS_SAMPLE_TRUST) + { + Node lem = nm->mkNode(OR, getGuard().negate(), rlem); + lems.push_back(lem); + } + return true; + } + else + { + Trace("cegis-sample-debug") << "...duplicate." << std::endl; + } + } + } + } + return false; +} + }/* namespace CVC4::theory::quantifiers */ }/* namespace CVC4::theory */ }/* namespace CVC4 */ diff --git a/src/theory/quantifiers/ce_guided_conjecture.h b/src/theory/quantifiers/ce_guided_conjecture.h index 011967ca1..dae261111 100644 --- a/src/theory/quantifiers/ce_guided_conjecture.h +++ b/src/theory/quantifiers/ce_guided_conjecture.h @@ -75,9 +75,6 @@ public: * This is step 2(b) of Figure 3 of Reynolds et al CAV 2015. */ void doRefine(std::vector< Node >& lems); - /** Print the synthesis solution - * singleInvocation is whether the solution was found by single invocation techniques. - */ //-------------------------------end for counterexample-guided check/refine /** * prints the synthesis solution to output stream out. @@ -124,10 +121,21 @@ public: //-----------------------------------refinement lemmas /** get number of refinement lemmas we have added so far */ unsigned getNumRefinementLemmas() { return d_refinement_lemmas.size(); } - /** get refinement lemma */ + /** get refinement lemma + * + * If d_embed_quant is forall d. exists y. P( d, y ), then a refinement + * lemma is one of the form ~P( d_candidates, c ) for some c. + */ Node getRefinementLemma( unsigned i ) { return d_refinement_lemmas[i]; } - /** get refinement lemma */ - Node getRefinementBaseLemma( unsigned i ) { return d_refinement_lemmas_base[i]; } + /** sample add refinement lemma + * + * This function will check if there is a sample point in d_sampler that + * refutes the candidate solution (d_quant_vars->vals). If so, it adds a + * refinement lemma to the lists d_refinement_lemmas that corresponds to that + * sample point, and adds a lemma to lems if cegisSample mode is not trust. + */ + bool sampleAddRefinementLemma(std::vector& vals, + std::vector& lems); //-----------------------------------end refinement lemmas /** get program by examples utility */ @@ -151,14 +159,21 @@ private: /** grammar utility */ std::unique_ptr d_ceg_gc; /** list of constants for quantified formula - * The Skolems for the negation of d_embed_quant. + * The outer Skolems for the negation of d_embed_quant. */ std::vector< Node > d_candidates; /** base instantiation * If d_embed_quant is forall d. exists y. P( d, y ), then - * this is the formula P( candidates, y ). + * this is the formula exists y. P( d_candidates, y ). */ Node d_base_inst; + /** If d_base_inst is exists y. P( d, y ), then this is y. */ + std::vector d_base_vars; + /** + * If d_base_inst is exists y. P( d, y ), then this is the formula + * P( d_candidates, y ). + */ + Node d_base_body; /** expand base inst to disjuncts */ std::vector< Node > d_base_disj; /** list of variables on inner quantification */ @@ -170,14 +185,13 @@ private: //-----------------------------------refinement lemmas /** refinement lemmas */ std::vector< Node > d_refinement_lemmas; - std::vector< Node > d_refinement_lemmas_base; //-----------------------------------end refinement lemmas - /** quantified formula asserted */ + /** the asserted (negated) conjecture */ Node d_quant; - /** quantified formula (after simplification) */ + /** (negated) conjecture after simplification */ Node d_simp_quant; - /** quantified formula (after simplification, conversion to deep embedding) */ + /** (negated) conjecture after simplification, conversion to deep embedding */ Node d_embed_quant; /** candidate information */ class CandidateInfo { @@ -227,6 +241,12 @@ private: std::vector< Node > d_stream_guards; /** get current stream guard */ Node getCurrentStreamGuard() const; + /** get stream guarded lemma + * + * If sygusStream is enabled, this returns ( G V n ) where G is the guard + * returned by getCurrentStreamGuard, otherwise this returns n. + */ + Node getStreamGuardedLemma(Node n) const; //-------------------------------- end sygus stream //-------------------------------- non-syntax guided (deprecated) /** Whether we are syntax-guided (e.g. was the input in SyGuS format). @@ -242,6 +262,18 @@ private: * rewrite rules. */ std::map d_sampler; + /** sampler object for the option cegisSample() + * + * This samples points of the type of the inner variables of the synthesis + * conjecture (d_base_vars). + */ + SygusSampler d_cegis_sampler; + /** cegis sample refine points + * + * Stores the list of indices of sample points in d_cegis_sampler we have + * added as refinement lemmas. + */ + std::unordered_set d_cegis_sample_refine; }; } /* namespace CVC4::theory::quantifiers */ diff --git a/src/theory/quantifiers/ce_guided_instantiation.cpp b/src/theory/quantifiers/ce_guided_instantiation.cpp index dc359d252..38cfb9ba7 100644 --- a/src/theory/quantifiers/ce_guided_instantiation.cpp +++ b/src/theory/quantifiers/ce_guided_instantiation.cpp @@ -238,17 +238,33 @@ void CegInstantiation::checkCegConjecture( CegConjecture * conj ) { void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vector< Node >& vs, std::vector< Node >& ms, std::vector< Node >& lems ) { Trace("sygus-cref-eval") << "Cref eval : conjecture has " << conj->getNumRefinementLemmas() << " refinement lemmas." << std::endl; - if( conj->getNumRefinementLemmas()>0 ){ + unsigned nlemmas = conj->getNumRefinementLemmas(); + if (nlemmas > 0 || options::cegisSample() != CEGIS_SAMPLE_NONE) + { Assert( vs.size()==ms.size() ); TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus(); Node nfalse = d_quantEngine->getTermUtil()->d_false; Node neg_guard = conj->getGuard().negate(); - for( unsigned i=0; igetNumRefinementLemmas(); i++ ){ + for (unsigned i = 0; i <= nlemmas; i++) + { + if (i == nlemmas) + { + bool addedSample = false; + // find a new one by sampling, if applicable + if (options::cegisSample() != CEGIS_SAMPLE_NONE) + { + addedSample = conj->sampleAddRefinementLemma(ms, lems); + } + if (!addedSample) + { + return; + } + } Node lem; std::map< Node, Node > visited; std::map< Node, std::vector< Node > > exp; - lem = conj->getRefinementBaseLemma( i ); + lem = conj->getRefinementLemma(i); if( !lem.isNull() ){ std::vector< Node > lem_conj; //break into conjunctions diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index b5e63a6ab..0b8f390f3 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -65,7 +65,25 @@ Node LazyTrie::add(Node n, SygusSampler::SygusSampler() : d_tds(nullptr), d_is_valid(false) {} -void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples) +void SygusSampler::initialize(TypeNode tn, + std::vector& vars, + unsigned nsamples) +{ + d_tds = nullptr; + d_is_valid = true; + d_tn = tn; + d_ftn = TypeNode::null(); + d_vars.insert(d_vars.end(), vars.begin(), vars.end()); + for (const Node& sv : vars) + { + TypeNode svt = sv.getType(); + d_var_index[sv] = d_type_vars[svt].size(); + d_type_vars[svt].push_back(sv); + } + initializeSamples(nsamples); +} + +void SygusSampler::initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples) { d_tds = tds; d_is_valid = true; @@ -73,12 +91,12 @@ void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples) Assert(d_ftn.isDatatype()); const Datatype& dt = static_cast(d_ftn.toType()).getDatatype(); Assert(dt.isSygus()); + d_tn = TypeNode::fromType(dt.getSygusType()); Trace("sygus-sample") << "Register sampler for " << f << std::endl; d_var_index.clear(); d_type_vars.clear(); - std::vector types; // get the sygus variable list Node var_list = Node::fromExpr(dt.getSygusVarList()); if (!var_list.isNull()) @@ -87,14 +105,24 @@ void SygusSampler::initialize(TermDbSygus* tds, Node f, unsigned nsamples) { TypeNode svt = sv.getType(); d_var_index[sv] = d_type_vars[svt].size(); + d_vars.push_back(sv); d_type_vars[svt].push_back(sv); - types.push_back(svt); - Trace("sygus-sample") << " var #" << types.size() << " : " << sv << " : " - << svt << std::endl; } } + initializeSamples(nsamples); +} +void SygusSampler::initializeSamples(unsigned nsamples) +{ d_samples.clear(); + std::vector types; + for (const Node& v : d_vars) + { + TypeNode vt = v.getType(); + types.push_back(vt); + Trace("sygus-sample") << " var #" << types.size() << " : " << v << " : " + << vt << std::endl; + } for (unsigned i = 0; i < nsamples; i++) { std::vector sample_pt; @@ -121,6 +149,7 @@ Node SygusSampler::registerTerm(Node n, bool forceKeep) { if (d_is_valid) { + Assert(n.getType() == d_tn); return d_trie.add(n, this, 0, d_samples.size(), forceKeep); } return n; @@ -254,10 +283,20 @@ bool SygusSampler::containsFreeVariables(Node a, Node b) return true; } +void SygusSampler::getSamplePoint(unsigned index, std::vector& pt) +{ + Assert(index < d_samples.size()); + std::vector& spt = d_samples[index]; + pt.insert(pt.end(), spt.begin(), spt.end()); +} + Node SygusSampler::evaluate(Node n, unsigned index) { Assert(index < d_samples.size()); - Node ev = d_tds->evaluateBuiltin(d_ftn, n, d_samples[index]); + // just a substitution + std::vector& pt = d_samples[index]; + Node ev = n.substitute(d_vars.begin(), d_vars.end(), pt.begin(), pt.end()); + ev = Rewriter::rewrite(ev); Trace("sygus-sample-ev") << "( " << n << ", " << index << " ) -> " << ev << std::endl; return ev; diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index 897931649..09f4124fe 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -137,12 +137,19 @@ class SygusSampler : public LazyTrieEvaluator virtual ~SygusSampler() {} /** initialize * - * tds : reference to a sygus database, + * tn : the return type of terms we will be testing with this class + * vars : the variables we are testing substitutions for + * nsamples : number of sample points this class will test. + */ + void initialize(TypeNode tn, std::vector& vars, unsigned nsamples); + /** initialize sygus + * + * tds : pointer to sygus database, * f : a term of some SyGuS datatype type whose (builtin) values we will be - * testing, + * testing under the free variables in the grammar of f, * nsamples : number of sample points this class will test. */ - void initialize(TermDbSygus* tds, Node f, unsigned nsamples); + void initializeSygus(TermDbSygus* tds, Node f, unsigned nsamples); /** register term n with this sampler database * * forceKeep is whether we wish to force that n is chosen as a representative @@ -172,6 +179,13 @@ class SygusSampler : public LazyTrieEvaluator * are those that occur in the range d_type_vars. */ bool containsFreeVariables(Node a, Node b); + /** get number of sample points */ + unsigned getNumSamplePoints() const { return d_samples.size(); } + /** get sample point + * + * Appends sample point #index to the vector pt. + */ + void getSamplePoint(unsigned index, std::vector& pt); /** evaluate n on sample point index */ Node evaluate(Node n, unsigned index); @@ -181,7 +195,11 @@ class SygusSampler : public LazyTrieEvaluator /** samples */ std::vector > d_samples; /** type of nodes we will be registering with this class */ + TypeNode d_tn; + /** the sygus type for this sampler (if applicable). */ TypeNode d_ftn; + /** all variables */ + std::vector d_vars; /** type variables * * For each type, a list of variables in the grammar we are considering, for @@ -213,6 +231,11 @@ class SygusSampler : public LazyTrieEvaluator * store these in the vector fvs. */ void computeFreeVariables(Node n, std::vector& fvs); + /** initialize samples + * + * Adds nsamples sample points to d_samples. + */ + void initializeSamples(unsigned nsamples); /** get random value for a type * * Returns a random value for the given type based on the random number -- 2.30.2