From: Andrew Reynolds Date: Wed, 27 Jun 2018 19:12:17 +0000 (-0500) Subject: Synthesize candidate-rewrites from standard inputs (#1918) X-Git-Tag: cvc5-1.0.0~4940 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=d6c7967cfc7a9f8530f0de50f12f99bfc5f93da7;p=cvc5.git Synthesize candidate-rewrites from standard inputs (#1918) --- diff --git a/src/Makefile.am b/src/Makefile.am index b36c453e1..b81d93081 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -86,6 +86,8 @@ libcvc4_la_SOURCES = \ preprocessing/passes/symmetry_breaker.h \ preprocessing/passes/symmetry_detect.cpp \ preprocessing/passes/symmetry_detect.h \ + preprocessing/passes/synth_rew_rules.cpp \ + preprocessing/passes/synth_rew_rules.h \ preprocessing/preprocessing_pass.cpp \ preprocessing/preprocessing_pass.h \ preprocessing/preprocessing_pass_context.cpp \ diff --git a/src/options/smt_options.toml b/src/options/smt_options.toml index 822f5c022..ce7b3eeba 100644 --- a/src/options/smt_options.toml +++ b/src/options/smt_options.toml @@ -295,6 +295,22 @@ header = "options/smt_options.h" default = "false" help = "use aggressive extended rewriter as a preprocessing pass" +[[option]] + name = "synthRrPrep" + category = "regular" + long = "synth-rr-prep" + type = "bool" + default = "false" + help = "synthesize and output rewrite rules during preprocessing" + +[[option]] + name = "synthRrPrepExtRew" + category = "regular" + long = "synth-rr-prep-ext-rew" + type = "bool" + default = "false" + help = "use the extended rewriter for synthRrPrep" + [[option]] name = "simplifyWithCareEnabled" category = "regular" diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp new file mode 100644 index 000000000..e3e3a548a --- /dev/null +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -0,0 +1,159 @@ +/********************* */ +/*! \file synth_rew_rules.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2, + ** where t1 and t2 are subterms of the input. + **/ + +#include "preprocessing/passes/synth_rew_rules.h" + +#include "options/base_options.h" +#include "options/quantifiers_options.h" +#include "printer/printer.h" +#include "theory/quantifiers/candidate_rewrite_database.h" + +using namespace std; + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +// Attribute for whether we have computed rewrite rules for a given term. +// Notice that this currently must be a global attribute, since if +// we've computed rewrites for a term, we should not compute rewrites for the +// same term in a subcall to another SmtEngine (for instance, when using +// "exact" equivalence checking). +struct SynthRrComputedAttributeId +{ +}; +typedef expr::Attribute + SynthRrComputedAttribute; + +SynthRewRulesPass::SynthRewRulesPass(PreprocessingPassContext* preprocContext) + : PreprocessingPass(preprocContext, "synth-rr"){}; + +PreprocessingPassResult SynthRewRulesPass::applyInternal( + AssertionPipeline* assertionsToPreprocess) +{ + Trace("synth-rr-pass") << "Synthesize rewrite rules from assertions..." + << std::endl; + std::vector& assertions = assertionsToPreprocess->ref(); + + // compute the variables we will be sampling + std::vector vars; + unsigned nsamples = options::sygusSamples(); + + Options& nodeManagerOptions = NodeManager::currentNM()->getOptions(); + + // attribute to mark processed terms + SynthRrComputedAttribute srrca; + + // initialize the candidate rewrite + std::unique_ptr crdg; + std::unordered_map visited; + std::unordered_map::iterator it; + std::vector visit; + // two passes: the first collects the variables, the second registers the + // terms + for (unsigned r = 0; r < 2; r++) + { + visited.clear(); + visit.clear(); + TNode cur; + for (const Node& a : assertions) + { + visit.push_back(a); + do + { + cur = visit.back(); + visit.pop_back(); + it = visited.find(cur); + // if already processed, ignore + if (cur.getAttribute(SynthRrComputedAttribute())) + { + Trace("synth-rr-pass-debug") + << "...already processed " << cur << std::endl; + } + else if (it == visited.end()) + { + Trace("synth-rr-pass-debug") << "...preprocess " << cur << std::endl; + visited[cur] = false; + Kind k = cur.getKind(); + bool isQuant = k == kind::FORALL || k == kind::EXISTS + || k == kind::LAMBDA || k == kind::CHOICE; + // we recurse on this node if it is not a quantified formula + if (!isQuant) + { + visit.push_back(cur); + for (const Node& cc : cur) + { + visit.push_back(cc); + } + } + } + else if (!it->second) + { + Trace("synth-rr-pass-debug") << "...postprocess " << cur << std::endl; + // check if all of the children are valid + // this ensures we do not register terms that have e.g. quantified + // formulas as subterms + bool childrenValid = true; + for (const Node& cc : cur) + { + Assert(visited.find(cc) != visited.end()); + if (!visited[cc]) + { + childrenValid = false; + } + } + if (childrenValid) + { + Trace("synth-rr-pass-debug") + << "...children are valid, check rewrites..." << std::endl; + if (r == 0) + { + if (cur.isVar()) + { + vars.push_back(cur); + } + } + else + { + Trace("synth-rr-pass-debug") << "Add term " << cur << std::endl; + // mark as processed + cur.setAttribute(srrca, true); + bool ret = crdg->addTerm(cur, *nodeManagerOptions.getOut()); + Trace("synth-rr-pass-debug") << "...return " << ret << std::endl; + // if we want only rewrites of minimal size terms, we would set + // childrenValid to false if ret is false here. + } + } + visited[cur] = childrenValid; + } + } while (!visit.empty()); + } + if (r == 0) + { + Trace("synth-rr-pass-debug") + << "Initialize with " << nsamples + << " samples and variables : " << vars << std::endl; + crdg = std::unique_ptr( + new theory::quantifiers::CandidateRewriteDatabaseGen(vars, nsamples)); + } + } + + Trace("synth-rr-pass") << "...finished " << std::endl; + return PreprocessingPassResult::NO_CONFLICT; +} + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 diff --git a/src/preprocessing/passes/synth_rew_rules.h b/src/preprocessing/passes/synth_rew_rules.h new file mode 100644 index 000000000..cf0b491fb --- /dev/null +++ b/src/preprocessing/passes/synth_rew_rules.h @@ -0,0 +1,48 @@ +/********************* */ +/*! \file synth_rew_rules.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A technique for synthesizing candidate rewrites of the form t1 = t2, + ** where t1 and t2 are subterms of the input. + **/ + +#ifndef __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H +#define __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H + +#include "preprocessing/preprocessing_pass.h" +#include "preprocessing/preprocessing_pass_context.h" + +namespace CVC4 { +namespace preprocessing { +namespace passes { + +/** + * This class computes candidate rewrite rules of the form t1 = t2, where + * t1 and t2 are subterms of assertionsToPreprocess. It prints + * "candidate-rewrite" messages on the output stream of options. + * + * In contrast to other preprocessing passes, this pass does not modify + * the set of assertions. + */ +class SynthRewRulesPass : public PreprocessingPass +{ + public: + SynthRewRulesPass(PreprocessingPassContext* preprocContext); + + protected: + PreprocessingPassResult applyInternal( + AssertionPipeline* assertionsToPreprocess) override; +}; + +} // namespace passes +} // namespace preprocessing +} // namespace CVC4 + +#endif /* __CVC4__PREPROCESSING__PASSES__SYNTH_REW_RULES_H */ diff --git a/src/smt/smt_engine.cpp b/src/smt/smt_engine.cpp index 5652eeaa6..ae0e80512 100644 --- a/src/smt/smt_engine.cpp +++ b/src/smt/smt_engine.cpp @@ -80,6 +80,7 @@ #include "preprocessing/passes/static_learning.h" #include "preprocessing/passes/symmetry_breaker.h" #include "preprocessing/passes/symmetry_detect.h" +#include "preprocessing/passes/synth_rew_rules.h" #include "preprocessing/preprocessing_pass.h" #include "preprocessing/preprocessing_pass_context.h" #include "preprocessing/preprocessing_pass_registry.h" @@ -2727,6 +2728,8 @@ void SmtEnginePrivate::finishInit() { new StaticLearning(d_preprocessingPassContext.get())); std::unique_ptr sbProc( new SymBreakerPass(d_preprocessingPassContext.get())); + std::unique_ptr srrProc( + new SynthRewRulesPass(d_preprocessingPassContext.get())); d_preprocessingPassRegistry.registerPass("bool-to-bv", std::move(boolToBv)); d_preprocessingPassRegistry.registerPass("bv-abstraction", std::move(bvAbstract)); @@ -2743,6 +2746,7 @@ void SmtEnginePrivate::finishInit() { d_preprocessingPassRegistry.registerPass("static-learning", std::move(staticLearning)); d_preprocessingPassRegistry.registerPass("sym-break", std::move(sbProc)); + d_preprocessingPassRegistry.registerPass("synth-rr", std::move(srrProc)); } Node SmtEnginePrivate::expandDefinitions(TNode n, unordered_map& cache, bool expandOnly) @@ -4323,6 +4327,12 @@ void SmtEnginePrivate::processAssertions() { ->apply(&d_assertions); } + if (options::synthRrPrep()) + { + // do candidate rewrite rule synthesis + d_preprocessingPassRegistry.getPass("synth-rr")->apply(&d_assertions); + } + Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : pre-simplify" << endl; dumpAssertions("pre-simplify", d_assertions); Chat() << "simplifying assertions..." << endl; diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index 03c39f718..9bbb88699 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -32,37 +32,93 @@ namespace CVC4 { namespace theory { namespace quantifiers { -CandidateRewriteDatabase::CandidateRewriteDatabase() : d_qe(nullptr) {} -void CandidateRewriteDatabase::initialize(QuantifiersEngine* qe, - Node f, +// the number of d_drewrite objects we have allocated (to avoid name conflicts) +static unsigned drewrite_counter = 0; + +CandidateRewriteDatabase::CandidateRewriteDatabase() + : d_qe(nullptr), + d_tds(nullptr), + d_ext_rewrite(nullptr), + d_using_sygus(false) +{ + if (options::sygusRewSynthFilterCong()) + { + // initialize the dynamic rewriter + std::stringstream ss; + ss << "_dyn_rewriter_" << drewrite_counter; + drewrite_counter++; + d_drewrite = std::unique_ptr( + new DynamicRewriter(ss.str(), &d_fake_context)); + d_sampler.setDynamicRewriter(d_drewrite.get()); + } +} +void CandidateRewriteDatabase::initialize(ExtendedRewriter* er, + TypeNode tn, + std::vector& vars, unsigned nsamples, - bool useSygusType) + bool unique_type_ids) +{ + d_candidate = Node::null(); + d_type = tn; + d_using_sygus = false; + d_qe = nullptr; + d_tds = nullptr; + d_ext_rewrite = er; + d_sampler.initialize(tn, vars, nsamples, unique_type_ids); +} + +void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe, + Node f, + unsigned nsamples, + bool useSygusType) { - d_qe = qe; d_candidate = f; - d_sampler.initializeSygusExt(d_qe, f, nsamples, useSygusType); + d_type = f.getType(); + Assert(d_type.isDatatype()); + Assert(static_cast(d_type.toType()).getDatatype().isSygus()); + d_using_sygus = true; + d_qe = qe; + d_tds = d_qe->getTermDatabaseSygus(); + d_ext_rewrite = d_tds->getExtRewriter(); + d_sampler.initializeSygus(d_tds, f, nsamples, useSygusType); } -bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) +bool CandidateRewriteDatabase::addTerm(Node sol, + std::ostream& out, + bool& rew_print) { bool is_unique_term = true; - TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus(); Node eq_sol = d_sampler.registerTerm(sol); // eq_sol is a candidate solution that is equivalent to sol if (eq_sol != sol) { - CegInstantiation* cei = d_qe->getCegInstantiation(); is_unique_term = false; // if eq_sol is null, then we have an uninteresting candidate rewrite, // e.g. one that is alpha-equivalent to another. - bool success = true; if (!eq_sol.isNull()) { - ExtendedRewriter* er = sygusDb->getExtRewriter(); - Node solb = sygusDb->sygusToBuiltin(sol); - Node solbr = er->extendedRewrite(solb); - Node eq_solb = sygusDb->sygusToBuiltin(eq_sol); - Node eq_solr = er->extendedRewrite(eq_solb); + // get the actual term + Node solb = sol; + Node eq_solb = eq_sol; + if (d_using_sygus) + { + Assert(d_tds != nullptr); + solb = d_tds->sygusToBuiltin(sol); + eq_solb = d_tds->sygusToBuiltin(eq_sol); + } + // get the rewritten form + Node solbr; + Node eq_solr; + if (d_ext_rewrite != nullptr) + { + solbr = d_ext_rewrite->extendedRewrite(solb); + eq_solr = d_ext_rewrite->extendedRewrite(eq_solb); + } + else + { + solbr = Rewriter::rewrite(solb); + eq_solr = Rewriter::rewrite(eq_solb); + } bool verified = false; Trace("rr-check") << "Check candidate rewrite..." << std::endl; // verify it if applicable @@ -108,27 +164,36 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) if (r.asSatisfiabilityResult().isSat() == Result::SAT) { Trace("rr-check") << "...rewrite does not hold for: " << std::endl; - success = false; is_unique_term = true; std::vector vars; d_sampler.getVariables(vars); std::vector pt; for (const Node& v : vars) { - std::map::iterator itf = fv_index.find(v); Node val; - if (itf == fv_index.end()) + Node refv = v; + // if a bound variable, map to the skolem we introduce before + // looking up the model value + if (v.getKind() == BOUND_VARIABLE) { - // not in conjecture, can use arbitrary value - val = v.getType().mkGroundTerm(); + std::map::iterator itf = fv_index.find(v); + if (itf == fv_index.end()) + { + // not in conjecture, can use arbitrary value + val = v.getType().mkGroundTerm(); + } + else + { + // get the model value of its skolem + refv = sks[itf->second]; + } } - else + if (val.isNull()) { - // get the model value of its skolem - Node sk = sks[itf->second]; - val = Node::fromExpr(rrChecker.getValue(sk.toExpr())); - Trace("rr-check") << " " << v << " -> " << val << std::endl; + Assert(!refv.isNull() && refv.getKind() != BOUND_VARIABLE); + val = Node::fromExpr(rrChecker.getValue(refv.toExpr())); } + Trace("rr-check") << " " << v << " -> " << val << std::endl; pt.push_back(val); } d_sampler.addSamplePoint(pt); @@ -145,22 +210,29 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) else { // just insist that constants are not relevant pairs - success = !solb.isConst() || !eq_solb.isConst(); + is_unique_term = solb.isConst() && eq_solb.isConst(); } - if (success) + if (!is_unique_term) { // register this as a relevant pair (helps filtering) d_sampler.registerRelevantPair(sol, eq_sol); // The analog of terms sol and eq_sol are equivalent under // sample points but do not rewrite to the same term. Hence, // this indicates a candidate rewrite. - Printer* p = Printer::getPrinter(options::outputLanguage()); out << "(" << (verified ? "" : "candidate-") << "rewrite "; - p->toStreamSygus(out, sol); - out << " "; - p->toStreamSygus(out, eq_sol); + if (d_using_sygus) + { + Printer* p = Printer::getPrinter(options::outputLanguage()); + p->toStreamSygus(out, sol); + out << " "; + p->toStreamSygus(out, eq_sol); + } + else + { + out << sol << " " << eq_sol; + } out << ")" << std::endl; - ++(cei->d_statistics.d_candidate_rewrites_print); + rew_print = true; // debugging information if (Trace.isOn("sygus-rr-debug")) { @@ -169,32 +241,33 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) Trace("sygus-rr-debug") << "; candidate #2 ext-rewrites to: " << eq_solr << std::endl; } - if (options::sygusRewSynthAccel()) + if (options::sygusRewSynthAccel() && d_using_sygus) { + Assert(d_tds != nullptr); // Add a symmetry breaking clause that excludes the larger // of sol and eq_sol. This effectively states that we no longer // wish to enumerate any term that contains sol (resp. eq_sol) // as a subterm. Node exc_sol = sol; - unsigned sz = sygusDb->getSygusTermSize(sol); - unsigned eqsz = sygusDb->getSygusTermSize(eq_sol); + unsigned sz = d_tds->getSygusTermSize(sol); + unsigned eqsz = d_tds->getSygusTermSize(eq_sol); if (eqsz > sz) { sz = eqsz; exc_sol = eq_sol; } TypeNode ptn = d_candidate.getType(); - Node x = sygusDb->getFreeVar(ptn, 0); - Node lem = - sygusDb->getExplain()->getExplanationForEquality(x, exc_sol); + Node x = d_tds->getFreeVar(ptn, 0); + Node lem = d_tds->getExplain()->getExplanationForEquality(x, exc_sol); lem = lem.negate(); Trace("sygus-rr-sb") << "Symmetry breaking lemma : " << lem << std::endl; - sygusDb->registerSymBreakLemma(d_candidate, lem, ptn, sz); + d_tds->registerSymBreakLemma(d_candidate, lem, ptn, sz); } } } // We count this as a rewrite if we did not explicitly rule it out. + // The value of is_unique_term is false iff this call resulted in a rewrite. // Notice that when --sygus-rr-synth-check is enabled, // statistics on number of candidate rewrite rules is // an accurate count of (#enumerated_terms-#unique_terms) only if @@ -203,14 +276,52 @@ bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) // rule is not useful since its variables are unordered, whereby // it discards it as a redundant candidate rewrite rule before // checking its correctness. - if (success) - { - ++(cei->d_statistics.d_candidate_rewrites); - } } return is_unique_term; } +bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out) +{ + bool rew_print = false; + return addTerm(sol, out, rew_print); +} + +CandidateRewriteDatabaseGen::CandidateRewriteDatabaseGen( + std::vector& vars, unsigned nsamples) + : d_vars(vars.begin(), vars.end()), d_nsamples(nsamples) +{ +} + +bool CandidateRewriteDatabaseGen::addTerm(Node n, std::ostream& out) +{ + ExtendedRewriter* er = nullptr; + if (options::synthRrPrepExtRew()) + { + er = &d_ext_rewrite; + } + Node nr; + if (er == nullptr) + { + nr = Rewriter::rewrite(n); + } + else + { + nr = er->extendedRewrite(n); + } + TypeNode tn = nr.getType(); + std::map::iterator itc = d_cdbs.find(tn); + if (itc == d_cdbs.end()) + { + Trace("synth-rr-dbg") << "Initialize database for " << tn << std::endl; + // initialize with the extended rewriter owned by this class + d_cdbs[tn].initialize(er, tn, d_vars, d_nsamples, true); + itc = d_cdbs.find(tn); + Trace("synth-rr-dbg") << "...finish." << std::endl; + } + Trace("synth-rr-dbg") << "Add term " << nr << " for " << tn << std::endl; + return itc->second.addTerm(nr, out); +} + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/candidate_rewrite_database.h b/src/theory/quantifiers/candidate_rewrite_database.h index 9ca946d26..a2a6c5745 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.h +++ b/src/theory/quantifiers/candidate_rewrite_database.h @@ -18,6 +18,9 @@ #define __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_DATABASE_H #include +#include +#include +#include #include "theory/quantifiers/sygus_sampler.h" namespace CVC4 { @@ -43,7 +46,32 @@ class CandidateRewriteDatabase ~CandidateRewriteDatabase() {} /** Initialize this class * - * qe : pointer to quantifiers engine, + * er : pointer to the extended rewriter (if any) we are using to compute + * candidate rewrites, + * tn : the return type of terms we will be testing with this class, + * vars : the variables we are testing substitutions for, + * nsamples : number of sample points this class will test, + * unique_type_ids : if this is set to true, then each variable is treated + * as unique. This affects whether or not a rewrite rule is considered + * redundant or not. For example the rewrite f(y)=y is redundant if + * f(x)=x has also been printed as a rewrite and x and y have the same type + * id (see SygusSampler for details). On the other hand, when a candidate + * rewrite database is initialized with sygus below, the type ids of the + * (sygus formal argument list) variables are always computed and used. + */ + void initialize(ExtendedRewriter* er, + TypeNode tn, + std::vector& vars, + unsigned nsamples, + bool unique_type_ids = false); + /** Initialize this class + * + * Serves the same purpose as the above function, but we will be using + * sygus to enumerate terms and generate samples. + * + * qe : pointer to quantifiers engine. We use the sygus term database of this + * quantifiers engine, and the extended rewriter of the corresponding term + * database when computing candidate rewrites, * f : a term of some SyGuS datatype type whose values we will be * testing under the free variables in the grammar of f. This is the * "candidate variable" CegConjecture::d_candidates, @@ -55,28 +83,44 @@ class CandidateRewriteDatabase * * These arguments are used to initialize the sygus sampler class. */ - void initialize(QuantifiersEngine* qe, - Node f, - unsigned nsamples, - bool useSygusType); + void initializeSygus(QuantifiersEngine* qe, + Node f, + unsigned nsamples, + bool useSygusType); /** add term * * Notifies this class that the solution sol was enumerated. This may * cause a candidate-rewrite to be printed on the output stream out. + * We return true if the term sol is distinct (up to equivalence) with + * all previous terms added to this class. The argument rew_print is set to + * true if this class printed a rewrite. */ + bool addTerm(Node sol, std::ostream& out, bool& rew_print); bool addTerm(Node sol, std::ostream& out); private: /** reference to quantifier engine */ QuantifiersEngine* d_qe; - /** the function-to-synthesize we are testing */ + /** pointer to the sygus term database of d_qe */ + TermDbSygus* d_tds; + /** pointer to the extended rewriter object we are using */ + ExtendedRewriter* d_ext_rewrite; + /** the (sygus or builtin) type of terms we are testing */ + TypeNode d_type; + /** the function-to-synthesize we are testing (if sygus) */ Node d_candidate; + /** whether we are using sygus */ + bool d_using_sygus; /** sygus sampler objects for each program variable * * This is used for the sygusRewSynth() option to synthesize new candidate * rewrite rules. */ SygusSamplerExt d_sampler; + /** a (dummy) user context, used for d_drewrite */ + context::UserContext d_fake_context; + /** dynamic rewriter class */ + std::unique_ptr d_drewrite; /** * Cache of skolems for each free variable that appears in a synthesis check * (for --sygus-rr-synth-check). @@ -84,6 +128,41 @@ class CandidateRewriteDatabase std::map d_fv_to_skolem; }; +/** + * This class generates and stores candidate rewrite databases for multiple + * types as needed. + */ +class CandidateRewriteDatabaseGen +{ + public: + /** constructor + * + * vars : the variables we are testing substitutions for, for all types, + * nsamples : number of sample points this class will test for all types. + */ + CandidateRewriteDatabaseGen(std::vector& vars, unsigned nsamples); + /** add term + * + * This registers term n with this class. We generate the candidate rewrite + * database of the appropriate type (if not allocated already), and register + * n with this database. This may result in "candidate-rewrite" being + * printed on the output stream out. + */ + bool addTerm(Node n, std::ostream& out); + + private: + /** reference to quantifier engine */ + QuantifiersEngine* d_qe; + /** the variables */ + std::vector d_vars; + /** the number of samples */ + unsigned d_nsamples; + /** candidate rewrite databases for each type */ + std::map d_cdbs; + /** an extended rewriter object */ + ExtendedRewriter d_ext_rewrite; +}; + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/dynamic_rewrite.cpp b/src/theory/quantifiers/dynamic_rewrite.cpp index 352d6892f..ef1cb3a9d 100644 --- a/src/theory/quantifiers/dynamic_rewrite.cpp +++ b/src/theory/quantifiers/dynamic_rewrite.cpp @@ -23,9 +23,9 @@ namespace CVC4 { namespace theory { namespace quantifiers { -DynamicRewriter::DynamicRewriter(const std::string& name, QuantifiersEngine* qe) - : d_equalityEngine(qe->getUserContext(), "DynamicRewriter::" + name, true), - d_rewrites(qe->getUserContext()) +DynamicRewriter::DynamicRewriter(const std::string& name, + context::UserContext* u) + : d_equalityEngine(u, "DynamicRewriter::" + name, true), d_rewrites(u) { d_equalityEngine.addFunctionKind(kind::APPLY_UF); } @@ -42,6 +42,11 @@ void DynamicRewriter::addRewrite(Node a, Node b) // add to the equality engine Node ai = toInternal(a); Node bi = toInternal(b); + if (ai.isNull() || bi.isNull()) + { + Trace("dyn-rewrite") << "...not internalizable." << std::endl; + return; + } Trace("dyn-rewrite-debug") << "Internal : " << ai << " " << bi << std::endl; Trace("dyn-rewrite-debug") << "assert eq..." << std::endl; @@ -58,11 +63,19 @@ bool DynamicRewriter::areEqual(Node a, Node b) { return true; } + Trace("dyn-rewrite-debug") << "areEqual? : " << a << " " << b << std::endl; // add to the equality engine Node ai = toInternal(a); Node bi = toInternal(b); + if (ai.isNull() || bi.isNull()) + { + Trace("dyn-rewrite") << "...not internalizable." << std::endl; + return false; + } + Trace("dyn-rewrite-debug") << "internal : " << ai << " " << bi << std::endl; d_equalityEngine.addTerm(ai); d_equalityEngine.addTerm(bi); + Trace("dyn-rewrite-debug") << "...added terms" << std::endl; return d_equalityEngine.areEqual(ai, bi); } @@ -84,6 +97,12 @@ Node DynamicRewriter::toInternal(Node a) if (a.getKind() != APPLY_UF) { op = d_ois_trie[op].getSymbol(a); + // if this term involves an argument that is not of first class type, + // we cannot reason about it. This includes operators like str.in-re. + if (op.isNull()) + { + return Node::null(); + } } children.push_back(op); } @@ -120,6 +139,11 @@ Node DynamicRewriter::OpInternalSymTrie::getSymbol(Node n) OpInternalSymTrie* curr = this; for (unsigned i = 0, size = ctypes.size(); i < size; i++) { + // cannot handle certain types (e.g. regular expressions or functions) + if (!ctypes[i].isFirstClass()) + { + return Node::null(); + } curr = &(curr->d_children[ctypes[i]]); } if (!curr->d_sym.isNull()) diff --git a/src/theory/quantifiers/dynamic_rewrite.h b/src/theory/quantifiers/dynamic_rewrite.h index 0c115d8a1..75f668b11 100644 --- a/src/theory/quantifiers/dynamic_rewrite.h +++ b/src/theory/quantifiers/dynamic_rewrite.h @@ -20,7 +20,6 @@ #include #include "context/cdlist.h" -#include "theory/quantifiers_engine.h" #include "theory/uf/equality_engine.h" namespace CVC4 { @@ -55,7 +54,7 @@ class DynamicRewriter typedef context::CDList NodeList; public: - DynamicRewriter(const std::string& name, QuantifiersEngine* qe); + DynamicRewriter(const std::string& name, context::UserContext* u); ~DynamicRewriter() {} /** inform this class that the equality a = b holds. */ void addRewrite(Node a, Node b); diff --git a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp index 61869b355..3bb0fc51a 100644 --- a/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/sygus/ce_guided_conjecture.cpp @@ -658,11 +658,20 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation d_crrdb.find(prog); if (its == d_crrdb.end()) { - d_crrdb[prog].initialize( + d_crrdb[prog].initializeSygus( d_qe, d_candidates[i], options::sygusSamples(), true); its = d_crrdb.find(prog); } - is_unique_term = d_crrdb[prog].addTerm(sol, out); + bool rew_print = false; + is_unique_term = d_crrdb[prog].addTerm(sol, out, rew_print); + if (rew_print) + { + ++(cei->d_statistics.d_candidate_rewrites_print); + } + if (!is_unique_term) + { + ++(cei->d_statistics.d_candidate_rewrites); + } } if (is_unique_term) { diff --git a/src/theory/quantifiers/sygus/term_database_sygus.cpp b/src/theory/quantifiers/sygus/term_database_sygus.cpp index 26f26a145..c6976ac62 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.cpp +++ b/src/theory/quantifiers/sygus/term_database_sygus.cpp @@ -15,15 +15,15 @@ #include "theory/quantifiers/sygus/term_database_sygus.h" #include "base/cvc4_check.h" +#include "options/base_options.h" #include "options/quantifiers_options.h" +#include "printer/printer.h" #include "theory/arith/arith_msum.h" #include "theory/datatypes/datatypes_rewriter.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/term_database.h" #include "theory/quantifiers/term_util.h" #include "theory/quantifiers_engine.h" -#include "options/base_options.h" -#include "printer/printer.h" using namespace CVC4::kind; diff --git a/src/theory/quantifiers/sygus/term_database_sygus.h b/src/theory/quantifiers/sygus/term_database_sygus.h index 286533570..44139cf0d 100644 --- a/src/theory/quantifiers/sygus/term_database_sygus.h +++ b/src/theory/quantifiers/sygus/term_database_sygus.h @@ -185,8 +185,10 @@ class TermDbSygus { * form of bn [ args / vars(tn) ], where vars(tn) is the sygus variable * list for type tn (see Datatype::getSygusVarList). */ - Node evaluateBuiltin(TypeNode tn, Node bn, std::vector& args, -bool tryEval = true); + Node evaluateBuiltin(TypeNode tn, + Node bn, + std::vector& args, + bool tryEval = true); /** evaluate with unfolding * * n is any term that may involve sygus evaluation functions. This function diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index c290c027a..8da65e4ca 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -32,7 +32,8 @@ SygusSampler::SygusSampler() void SygusSampler::initialize(TypeNode tn, std::vector& vars, - unsigned nsamples) + unsigned nsamples, + bool unique_type_ids) { d_tds = nullptr; d_use_sygus_type = false; @@ -53,15 +54,23 @@ void SygusSampler::initialize(TypeNode tn, { TypeNode svt = sv.getType(); unsigned tnid = 0; - std::map::iterator itt = type_to_type_id.find(svt); - if (itt == type_to_type_id.end()) + if (unique_type_ids) { - type_to_type_id[svt] = type_id_counter; + tnid = type_id_counter; type_id_counter++; } else { - tnid = itt->second; + std::map::iterator itt = type_to_type_id.find(svt); + if (itt == type_to_type_id.end()) + { + type_to_type_id[svt] = type_id_counter; + type_id_counter++; + } + else + { + tnid = itt->second; + } } Trace("sygus-sample-debug") << "Type id for " << sv << " is " << tnid << std::endl; @@ -586,7 +595,7 @@ Node SygusSampler::getRandomValue(TypeNode tn) if (!s.isNull() && !r.isNull()) { Rational sr = s.getConst(); - Rational rr = s.getConst(); + Rational rr = r.getConst(); if (rr.sgn() == 0) { return s; @@ -597,7 +606,19 @@ Node SygusSampler::getRandomValue(TypeNode tn) } } } - return Node::null(); + // default: use type enumerator + unsigned counter = 0; + while (Random::getRandom().pickWithProb(0.5)) + { + counter++; + } + Node ret = d_tenum.getEnumerateTerm(tn, counter); + if (ret.isNull()) + { + // beyond bounds, return the first + ret = d_tenum.getEnumerateTerm(tn, 0); + } + return ret; } Node SygusSampler::getSygusRandomValue(TypeNode tn, @@ -719,28 +740,23 @@ void SygusSampler::registerSygusType(TypeNode tn) } } -SygusSamplerExt::SygusSamplerExt() : d_ssenm(*this) {} +SygusSamplerExt::SygusSamplerExt() : d_drewrite(nullptr), d_ssenm(*this) {} -void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe, - Node f, - unsigned nsamples, - bool useSygusType) +void SygusSamplerExt::initializeSygus(TermDbSygus* tds, + Node f, + unsigned nsamples, + bool useSygusType) { - SygusSampler::initializeSygus( - qe->getTermDatabaseSygus(), f, nsamples, useSygusType); - - // initialize the dynamic rewriter - std::stringstream ss; - ss << f; - if (options::sygusRewSynthFilterCong()) - { - d_drewrite = - std::unique_ptr(new DynamicRewriter(ss.str(), qe)); - } + SygusSampler::initializeSygus(tds, f, nsamples, useSygusType); d_pairs.clear(); d_match_trie.clear(); } +void SygusSamplerExt::setDynamicRewriter(DynamicRewriter* dr) +{ + d_drewrite = dr; +} + Node SygusSamplerExt::registerTerm(Node n, bool forceKeep) { Node eq_n = SygusSampler::registerTerm(n, forceKeep); @@ -896,6 +912,9 @@ bool SygusSamplerExt::notify(Node s, for (unsigned i = 0, size = vars.size(); i < size; i++) { Trace("sse-match") << " " << vars[i] << " -> " << subs[i] << std::endl; + // TODO (#1923) ensure that we use an internal representation to + // ensure polymorphism is handled correctly + Assert(vars[i].getType().isComparableTo(subs[i].getType())); } } Assert(it != d_pairs.end()); diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index fcd35613b..d323b36bd 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -21,6 +21,7 @@ #include "theory/quantifiers/dynamic_rewrite.h" #include "theory/quantifiers/lazy_trie.h" #include "theory/quantifiers/sygus/term_database_sygus.h" +#include "theory/quantifiers/term_enumeration.h" namespace CVC4 { namespace theory { @@ -69,14 +70,20 @@ class SygusSampler : public LazyTrieEvaluator /** initialize * - * tn : the return type of terms we will be testing with this class - * vars : the variables we are testing substitutions for - * nsamples : number of sample points this class will test. + * tn : the return type of terms we will be testing with this class, + * vars : the variables we are testing substitutions for, + * nsamples : number of sample points this class will test, + * unique_type_ids : if this is set to true, then we consider each variable + * in vars to have a unique "type id". A type id is a finer-grained notion of + * type that is used to determine when a rewrite rule is redundant. */ - void initialize(TypeNode tn, std::vector& vars, unsigned nsamples); + virtual void initialize(TypeNode tn, + std::vector& vars, + unsigned nsamples, + bool unique_type_ids = false); /** initialize sygus * - * tds : pointer to sygus database, + * qe : pointer to quantifiers engine, * f : a term of some SyGuS datatype type whose values we will be * testing under the free variables in the grammar of f, * nsamples : number of sample points this class will test, @@ -85,10 +92,10 @@ class SygusSampler : public LazyTrieEvaluator * terms of the analog of the type of f, that is, the builtin type that * f's type encodes in the deep embedding. */ - void initializeSygus(TermDbSygus* tds, - Node f, - unsigned nsamples, - bool useSygusType); + virtual void initializeSygus(TermDbSygus* tds, + Node f, + unsigned nsamples, + bool useSygusType); /** register term n with this sampler database * * forceKeep is whether we wish to force that n is chosen as a representative @@ -145,6 +152,8 @@ class SygusSampler : public LazyTrieEvaluator protected: /** sygus term database of d_qe */ TermDbSygus* d_tds; + /** term enumerator object (used for random sampling) */ + TermEnumeration d_tenum; /** samples */ std::vector > d_samples; /** data structure to check duplication of sample points */ @@ -330,11 +339,19 @@ class SygusSamplerExt : public SygusSampler { public: SygusSamplerExt(); - /** initialize extended */ - void initializeSygusExt(QuantifiersEngine* qe, - Node f, - unsigned nsamples, - bool useSygusType); + /** initialize */ + void initializeSygus(TermDbSygus* tds, + Node f, + unsigned nsamples, + bool useSygusType) override; + /** set dynamic rewriter + * + * This tells this class to use the dynamic rewriter object dr. This utility + * is used to query whether pairs of terms are already entailed to be + * equal based on previous rewrite rules. + */ + void setDynamicRewriter(DynamicRewriter* dr); + /** register term n with this sampler database * * For each call to registerTerm( t, ... ) that returns s, we say that @@ -366,7 +383,6 @@ class SygusSamplerExt : public SygusSampler * d_drewrite utility, or is an instance of a previous pair */ Node registerTerm(Node n, bool forceKeep = false) override; - /** register relevant pair * * This should be called after registerTerm( n ) returns eq_n. @@ -375,8 +391,8 @@ class SygusSamplerExt : public SygusSampler void registerRelevantPair(Node n, Node eq_n); private: - /** dynamic rewriter class */ - std::unique_ptr d_drewrite; + /** pointer to the dynamic rewriter class */ + DynamicRewriter* d_drewrite; //----------------------------match filtering /**