From: Andrew Reynolds Date: Wed, 4 Jul 2018 13:31:14 +0000 (+0100) Subject: Reorganize candidate rewrite rule filtering (#2116) X-Git-Tag: cvc5-1.0.0~4916 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9a8d9420f03ba27fc5cbb9674b0c809ecc53e85e;p=cvc5.git Reorganize candidate rewrite rule filtering (#2116) --- diff --git a/src/Makefile.am b/src/Makefile.am index b5da564cf..917fc6ef3 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -393,6 +393,8 @@ libcvc4_la_SOURCES = \ theory/quantifiers/bv_inverter.h \ theory/quantifiers/candidate_rewrite_database.cpp \ theory/quantifiers/candidate_rewrite_database.h \ + theory/quantifiers/candidate_rewrite_filter.cpp \ + theory/quantifiers/candidate_rewrite_filter.h \ theory/quantifiers/cegqi/ceg_instantiator.cpp \ theory/quantifiers/cegqi/ceg_instantiator.h \ theory/quantifiers/cegqi/ceg_arith_instantiator.cpp \ diff --git a/src/theory/quantifiers/candidate_rewrite_database.cpp b/src/theory/quantifiers/candidate_rewrite_database.cpp index 9bbb88699..a5a35f89d 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.cpp +++ b/src/theory/quantifiers/candidate_rewrite_database.cpp @@ -32,25 +32,12 @@ namespace CVC4 { namespace theory { namespace quantifiers { -// the number of d_drewrite objects we have allocated (to avoid name conflicts) -static unsigned drewrite_counter = 0; - CandidateRewriteDatabase::CandidateRewriteDatabase() : d_qe(nullptr), d_tds(nullptr), d_ext_rewrite(nullptr), d_using_sygus(false) { - if (options::sygusRewSynthFilterCong()) - { - // initialize the dynamic rewriter - std::stringstream ss; - ss << "_dyn_rewriter_" << drewrite_counter; - drewrite_counter++; - d_drewrite = std::unique_ptr( - new DynamicRewriter(ss.str(), &d_fake_context)); - d_sampler.setDynamicRewriter(d_drewrite.get()); - } } void CandidateRewriteDatabase::initialize(ExtendedRewriter* er, TypeNode tn, @@ -65,6 +52,7 @@ void CandidateRewriteDatabase::initialize(ExtendedRewriter* er, d_tds = nullptr; d_ext_rewrite = er; d_sampler.initialize(tn, vars, nsamples, unique_type_ids); + d_crewrite_filter.initialize(&d_sampler, nullptr, false); } void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe, @@ -81,6 +69,7 @@ void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe, d_tds = d_qe->getTermDatabaseSygus(); d_ext_rewrite = d_tds->getExtRewriter(); d_sampler.initializeSygus(d_tds, f, nsamples, useSygusType); + d_crewrite_filter.initialize(&d_sampler, d_tds, true); } bool CandidateRewriteDatabase::addTerm(Node sol, @@ -93,9 +82,8 @@ bool CandidateRewriteDatabase::addTerm(Node sol, if (eq_sol != sol) { is_unique_term = false; - // if eq_sol is null, then we have an uninteresting candidate rewrite, - // e.g. one that is alpha-equivalent to another. - if (!eq_sol.isNull()) + // should we filter the pair? + if (!d_crewrite_filter.filterPair(sol, eq_sol)) { // get the actual term Node solb = sol; @@ -215,7 +203,7 @@ bool CandidateRewriteDatabase::addTerm(Node sol, if (!is_unique_term) { // register this as a relevant pair (helps filtering) - d_sampler.registerRelevantPair(sol, eq_sol); + d_crewrite_filter.registerRelevantPair(sol, eq_sol); // The analog of terms sol and eq_sol are equivalent under // sample points but do not rewrite to the same term. Hence, // this indicates a candidate rewrite. diff --git a/src/theory/quantifiers/candidate_rewrite_database.h b/src/theory/quantifiers/candidate_rewrite_database.h index a2a6c5745..7f51043e2 100644 --- a/src/theory/quantifiers/candidate_rewrite_database.h +++ b/src/theory/quantifiers/candidate_rewrite_database.h @@ -21,6 +21,7 @@ #include #include #include +#include "theory/quantifiers/candidate_rewrite_filter.h" #include "theory/quantifiers/sygus_sampler.h" namespace CVC4 { @@ -116,11 +117,9 @@ class CandidateRewriteDatabase * This is used for the sygusRewSynth() option to synthesize new candidate * rewrite rules. */ - SygusSamplerExt d_sampler; - /** a (dummy) user context, used for d_drewrite */ - context::UserContext d_fake_context; - /** dynamic rewriter class */ - std::unique_ptr d_drewrite; + SygusSampler d_sampler; + /** candidate rewrite filter */ + CandidateRewriteFilter d_crewrite_filter; /** * Cache of skolems for each free variable that appears in a synthesis check * (for --sygus-rr-synth-check). diff --git a/src/theory/quantifiers/candidate_rewrite_filter.cpp b/src/theory/quantifiers/candidate_rewrite_filter.cpp new file mode 100644 index 000000000..68a3abe37 --- /dev/null +++ b/src/theory/quantifiers/candidate_rewrite_filter.cpp @@ -0,0 +1,413 @@ +/********************* */ +/*! \file candidate_rewrite_filter.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implements techniques for candidate rewrite rule filtering, which + ** filters the output of --sygus-rr-synth. The classes in this file implement + ** filtering based on congruence, variable ordering, and matching. + **/ + +#include "theory/quantifiers/candidate_rewrite_filter.h" + +#include "options/base_options.h" +#include "options/quantifiers_options.h" +#include "printer/printer.h" + +using namespace CVC4::kind; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +bool MatchTrie::getMatches(Node n, NotifyMatch* ntm) +{ + std::vector vars; + std::vector subs; + std::map smap; + + std::vector > visit; + std::vector visit_trie; + std::vector visit_var_index; + std::vector visit_bound_var; + + visit.push_back(std::vector{n}); + visit_trie.push_back(this); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + while (!visit.empty()) + { + std::vector cvisit = visit.back(); + MatchTrie* curr = visit_trie.back(); + if (cvisit.empty()) + { + Assert(n + == curr->d_data.substitute( + vars.begin(), vars.end(), subs.begin(), subs.end())); + Trace("crf-match-debug") << "notify : " << curr->d_data << std::endl; + if (!ntm->notify(n, curr->d_data, vars, subs)) + { + return false; + } + visit.pop_back(); + visit_trie.pop_back(); + visit_var_index.pop_back(); + visit_bound_var.pop_back(); + } + else + { + Node cn = cvisit.back(); + Trace("crf-match-debug") << "traverse : " << cn << " at depth " + << visit.size() << std::endl; + unsigned index = visit.size() - 1; + int vindex = visit_var_index[index]; + if (vindex == -1) + { + if (!cn.isVar()) + { + Node op = cn.hasOperator() ? cn.getOperator() : cn; + unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0; + std::map::iterator itu = + curr->d_children[op].find(nchild); + if (itu != curr->d_children[op].end()) + { + // recurse on the operator or self + cvisit.pop_back(); + if (cn.hasOperator()) + { + for (const Node& cnc : cn) + { + cvisit.push_back(cnc); + } + } + Trace("crf-match-debug") << "recurse op : " << op << std::endl; + visit.push_back(cvisit); + visit_trie.push_back(&itu->second); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + } + } + visit_var_index[index]++; + } + else + { + // clean up if we previously bound a variable + if (visit_bound_var[index]) + { + Assert(!vars.empty()); + smap.erase(vars.back()); + vars.pop_back(); + subs.pop_back(); + visit_bound_var[index] = false; + } + + if (vindex == static_cast(curr->d_vars.size())) + { + Trace("crf-match-debug") + << "finished checking " << curr->d_vars.size() + << " variables at depth " << visit.size() << std::endl; + // finished + visit.pop_back(); + visit_trie.pop_back(); + visit_var_index.pop_back(); + visit_bound_var.pop_back(); + } + else + { + Trace("crf-match-debug") << "check variable #" << vindex + << " at depth " << visit.size() << std::endl; + Assert(vindex < static_cast(curr->d_vars.size())); + // recurse on variable? + Node var = curr->d_vars[vindex]; + bool recurse = true; + // check if it is already bound + std::map::iterator its = smap.find(var); + if (its != smap.end()) + { + if (its->second != cn) + { + recurse = false; + } + } + else + { + vars.push_back(var); + subs.push_back(cn); + smap[var] = cn; + visit_bound_var[index] = true; + } + if (recurse) + { + Trace("crf-match-debug") << "recurse var : " << var << std::endl; + cvisit.pop_back(); + visit.push_back(cvisit); + visit_trie.push_back(&curr->d_children[var][0]); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + } + visit_var_index[index]++; + } + } + } + } + return true; +} + +void MatchTrie::addTerm(Node n) +{ + std::vector visit; + visit.push_back(n); + MatchTrie* curr = this; + while (!visit.empty()) + { + Node cn = visit.back(); + visit.pop_back(); + if (cn.hasOperator()) + { + curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]); + for (const Node& cnc : cn) + { + visit.push_back(cnc); + } + } + else + { + if (cn.isVar() + && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn) + == curr->d_vars.end()) + { + curr->d_vars.push_back(cn); + } + curr = &(curr->d_children[cn][0]); + } + } + curr->d_data = n; +} + +void MatchTrie::clear() +{ + d_children.clear(); + d_vars.clear(); + d_data = Node::null(); +} + +// the number of d_drewrite objects we have allocated (to avoid name conflicts) +static unsigned drewrite_counter = 0; + +CandidateRewriteFilter::CandidateRewriteFilter() + : d_ss(nullptr), + d_tds(nullptr), + d_use_sygus_type(false), + d_drewrite(nullptr), + d_ssenm(*this) +{ +} + +void CandidateRewriteFilter::initialize(SygusSampler* ss, + TermDbSygus* tds, + bool useSygusType) +{ + d_ss = ss; + d_use_sygus_type = false; + d_tds = tds; + // initialize members of this class + d_match_trie.clear(); + d_pairs.clear(); + if (options::sygusRewSynthFilterCong()) + { + // initialize the dynamic rewriter + std::stringstream ss; + ss << "_dyn_rewriter_" << drewrite_counter; + drewrite_counter++; + d_drewrite = std::unique_ptr( + new DynamicRewriter(ss.str(), &d_fake_context)); + } +} + +bool CandidateRewriteFilter::filterPair(Node n, Node eq_n) +{ + Node bn = n; + Node beq_n = eq_n; + if (d_use_sygus_type) + { + bn = d_tds->sygusToBuiltin(n); + beq_n = d_tds->sygusToBuiltin(eq_n); + } + Trace("cr-filter") << "crewriteFilter : " << bn << "..." << beq_n + << std::endl; + // whether we will keep this pair + bool keep = true; + + // ----- check ordering redundancy + if (options::sygusRewSynthFilterOrder()) + { + bool nor = d_ss->isOrdered(bn); + bool eqor = d_ss->isOrdered(beq_n); + Trace("cr-filter-debug") << "Ordered? : " << nor << " " << eqor + << std::endl; + if (eqor || nor) + { + // if only one is ordered, then we require that the ordered one's + // variables cannot be a strict subset of the variables of the other. + if (!eqor) + { + if (d_ss->containsFreeVariables(beq_n, bn, true)) + { + keep = false; + } + else + { + // if the previous value stored was unordered, but n is + // ordered, we prefer n. Thus, we force its addition to the + // sampler database. + d_ss->registerTerm(n, true); + } + } + else if (!nor) + { + keep = !d_ss->containsFreeVariables(bn, beq_n, true); + } + } + else + { + keep = false; + } + if (!keep) + { + Trace("cr-filter") << "...redundant (unordered)" << std::endl; + } + } + + // ----- check rewriting redundancy + if (keep && d_drewrite != nullptr) + { + Trace("cr-filter-debug") << "Check equal rewrite pair..." << std::endl; + if (d_drewrite->areEqual(bn, beq_n)) + { + // must be unique according to the dynamic rewriter + Trace("cr-filter") << "...redundant (rewritable)" << std::endl; + keep = false; + } + } + + if (options::sygusRewSynthFilterMatch()) + { + // ----- check matchable + // check whether the pair is matchable with a previous one + d_curr_pair_rhs = beq_n; + Trace("crf-match") << "CRF check matches : " << bn << " [rhs = " << beq_n + << "]..." << std::endl; + if (!d_match_trie.getMatches(bn, &d_ssenm)) + { + keep = false; + Trace("cr-filter") << "...redundant (matchable)" << std::endl; + // regardless, would help to remember it + registerRelevantPair(n, eq_n); + } + } + + if (keep) + { + return false; + } + if (Trace.isOn("sygus-rr-filter")) + { + Printer* p = Printer::getPrinter(options::outputLanguage()); + std::stringstream ss; + ss << "(redundant-rewrite "; + p->toStreamSygus(ss, n); + ss << " "; + p->toStreamSygus(ss, eq_n); + ss << ")"; + Trace("sygus-rr-filter") << ss.str() << std::endl; + } + return true; +} + +void CandidateRewriteFilter::registerRelevantPair(Node n, Node eq_n) +{ + Node bn = n; + Node beq_n = eq_n; + if (d_use_sygus_type) + { + bn = d_tds->sygusToBuiltin(n); + beq_n = d_tds->sygusToBuiltin(eq_n); + } + // ----- check rewriting redundancy + if (d_drewrite != nullptr) + { + Trace("cr-filter-debug") << "Add rewrite pair..." << std::endl; + Assert(!d_drewrite->areEqual(bn, beq_n)); + d_drewrite->addRewrite(bn, beq_n); + } + if (options::sygusRewSynthFilterMatch()) + { + // add to match information + for (unsigned r = 0; r < 2; r++) + { + Node t = r == 0 ? bn : beq_n; + Node to = r == 0 ? beq_n : bn; + // insert in match trie if first time + if (d_pairs.find(t) == d_pairs.end()) + { + Trace("crf-match") << "CRF add term : " << t << std::endl; + d_match_trie.addTerm(t); + } + d_pairs[t].insert(to); + } + } +} + +bool CandidateRewriteFilter::notify(Node s, + Node n, + std::vector& vars, + std::vector& subs) +{ + Assert(!d_curr_pair_rhs.isNull()); + std::map >::iterator it = + d_pairs.find(n); + if (Trace.isOn("crf-match")) + { + Trace("crf-match") << " " << s << " matches " << n + << " under:" << std::endl; + for (unsigned i = 0, size = vars.size(); i < size; i++) + { + Trace("crf-match") << " " << vars[i] << " -> " << subs[i] << std::endl; + // TODO (#1923) ensure that we use an internal representation to + // ensure polymorphism is handled correctly + Assert(vars[i].getType().isComparableTo(subs[i].getType())); + } + } + Assert(it != d_pairs.end()); + for (const Node& nr : it->second) + { + Node nrs = + nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + bool areEqual = (nrs == d_curr_pair_rhs); + if (!areEqual && d_drewrite != nullptr) + { + // if dynamic rewriter is available, consult it + areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs); + } + if (areEqual) + { + Trace("crf-match") << "*** Match, current pair: " << std::endl; + Trace("crf-match") << " (" << s << ", " << d_curr_pair_rhs << ")" + << std::endl; + Trace("crf-match") << "is an instance of previous pair:" << std::endl; + Trace("crf-match") << " (" << n << ", " << nr << ")" << std::endl; + return false; + } + } + return true; +} + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/quantifiers/candidate_rewrite_filter.h b/src/theory/quantifiers/candidate_rewrite_filter.h new file mode 100644 index 000000000..9a09680cc --- /dev/null +++ b/src/theory/quantifiers/candidate_rewrite_filter.h @@ -0,0 +1,218 @@ +/********************* */ +/*! \file candidate_rewrite_filter.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implements techniques for candidate rewrite rule filtering, which + ** filters the output of --sygus-rr-synth. The classes in this file implement + ** filtering based on congruence, variable ordering, and matching. + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H +#define __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H + +#include +#include "theory/quantifiers/dynamic_rewrite.h" +#include "theory/quantifiers/sygus/term_database_sygus.h" +#include "theory/quantifiers/sygus_sampler.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** A virtual class for notifications regarding matches. */ +class NotifyMatch +{ + public: + virtual ~NotifyMatch() {} + /** + * A notification that s is equal to n * { vars -> subs }. This function + * should return false if we do not wish to be notified of further matches. + */ + virtual bool notify(Node s, + Node n, + std::vector& vars, + std::vector& subs) = 0; +}; + +/** + * A trie (discrimination tree) storing a set of terms S, that can be used to + * query, for a given term t, all terms s from S that are matchable with t, + * that is s*sigma = t for some substitution sigma. + */ +class MatchTrie +{ + public: + /** Get matches + * + * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this + * trie that is matchable with n where s = n * { vars -> subs } for some + * vars, subs. This function returns false if one of these calls to notify + * returns false. + */ + bool getMatches(Node n, NotifyMatch* ntm); + /** Adds node n to this trie */ + void addTerm(Node n); + /** Clear this trie */ + void clear(); + + private: + /** + * The children of this node in the trie. Terms t are indexed by a + * depth-first (right to left) traversal on its subterms, where the + * top-symbol of t is indexed by: + * - (operator, #children) if t has an operator, or + * - (t, 0) if t does not have an operator. + */ + std::map > d_children; + /** The set of variables in the domain of d_children */ + std::vector d_vars; + /** The data of this node in the trie */ + Node d_data; +}; + +/** candidate rewrite filter + * + * This class is responsible for various filtering techniques for candidate + * rewrite rules, including: + * (1) filtering based on variable ordering, + * (2) filtering based on congruence, + * (3) filtering based on matching. + * For details, see Reynolds et al "Rewrites for SMT Solvers using Syntax-Guided + * Enumeration", SMT 2018. + * + * In the following, we assume that the registerRelevantPair method of this + * class been called for some pairs of terms. For each such call to + * registerRelevantPair( t, s ), we say that (t,s) and (s,t) are "relevant + * pairs". A relevant pair ( t, s ) typically corresponds to a (candidate) + * rewrite t = s. + */ +class CandidateRewriteFilter +{ + public: + CandidateRewriteFilter(); + + /** initialize + * + * Initializes this class, ss is the sygus sampler that this class is + * filtering rewrite rule pairs for, and tds is a point to the sygus term + * database utility class. + * + * If useSygusType is false, this means that the terms in filterPair and + * registerRelevantPair calls should be interpreted as themselves. Otherwise, + * if useSygusType is true, these terms should be interpreted as their + * analog according to the deep embedding. + */ + void initialize(SygusSampler* ss, TermDbSygus* tds, bool useSygusType); + /** filter pair + * + * This method returns true if the pair (n, eq_n) should be filtered. If it + * is not filtered, then the caller may choose to call + * registerRelevantPair(n, eq_n) below, although it may not, say if it finds + * another reason to discard the pair. + * + * If this method returns false, then for all previous relevant pairs + * ( a, eq_a ), we have that n = eq_n is not an instance of a = eq_a + * modulo symmetry of equality, nor is n = eq_n derivable from the set of + * all previous relevant pairs. The latter is determined by the d_drewrite + * utility. For example: + * [1] ( t+0, t ) and ( x+0, x ) + * will not both be relevant pairs of this function since t+0=t is + * an instance of x+0=x. + * [2] ( s, t ) and ( s+0, t+0 ) + * will not both be relevant pairs of this function since s+0=t+0 is + * derivable from s=t. + * These two criteria may be combined, for example: + * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s ) + * are relevant pairs, since t+0 is an instance of x+0 where + * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair + * above (t+s = s). + */ + bool filterPair(Node n, Node eq_n); + /** register relevant pair + * + * This should be called after filterPair( n, eq_n ) returns false. + * This registers ( n, eq_n ) as a relevant pair with this class. + */ + void registerRelevantPair(Node n, Node eq_n); + + private: + /** pointer to the sygus sampler that this class is filtering rewrites for */ + SygusSampler* d_ss; + /** pointer to the sygus term database, used if d_use_sygus_type is true */ + TermDbSygus* d_tds; + /** whether we are registering sygus terms with this class */ + bool d_use_sygus_type; + + //----------------------------congruence filtering + /** a (dummy) user context, used for d_drewrite */ + context::UserContext d_fake_context; + /** dynamic rewriter class */ + std::unique_ptr d_drewrite; + //----------------------------end congruence filtering + + //----------------------------match filtering + /** + * Stores all relevant pairs returned by this sampler (see registerTerm). In + * detail, if (t,s) is a relevant pair, then t in d_pairs[s]. + */ + std::map > d_pairs; + /** Match trie storing all terms in the domain of d_pairs. */ + MatchTrie d_match_trie; + /** Notify class */ + class CandidateRewriteFilterNotifyMatch : public NotifyMatch + { + CandidateRewriteFilter& d_sse; + + public: + CandidateRewriteFilterNotifyMatch(CandidateRewriteFilter& sse) : d_sse(sse) + { + } + /** notify match */ + bool notify(Node n, + Node s, + std::vector& vars, + std::vector& subs) override + { + return d_sse.notify(n, s, vars, subs); + } + }; + /** Notify object used for reporting matches from d_match_trie */ + CandidateRewriteFilterNotifyMatch d_ssenm; + /** + * Stores the current right hand side of a pair we are considering. + * + * In more detail, in registerTerm, we are interested in whether a pair (s,t) + * is a relevant pair. We do this by: + * (1) Setting the node d_curr_pair_rhs to t, + * (2) Using d_match_trie, compute all terms s1...sn that match s. + * For each si, where s = si * sigma for some substitution sigma, we check + * whether t = ti * sigma for some previously relevant pair (si,ti), in + * which case (s,t) is an instance of (si,ti). + */ + Node d_curr_pair_rhs; + /** + * Called by the above class during d_match_trie.getMatches( s ), when we + * find that si = s * sigma, where si is a term that is stored in + * d_match_trie. + * + * This function returns false if ( s, d_curr_pair_rhs ) is an instance of + * previously relevant pair. + */ + bool notify(Node s, Node n, std::vector& vars, std::vector& subs); + //----------------------------end match filtering +}; + +} // namespace quantifiers +} // namespace theory +} // namespace CVC4 + +#endif /* __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H */ diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index e07f73540..ebd10c585 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -746,378 +746,6 @@ void SygusSampler::registerSygusType(TypeNode tn) } } -SygusSamplerExt::SygusSamplerExt() : d_drewrite(nullptr), d_ssenm(*this) {} - -void SygusSamplerExt::initializeSygus(TermDbSygus* tds, - Node f, - unsigned nsamples, - bool useSygusType) -{ - SygusSampler::initializeSygus(tds, f, nsamples, useSygusType); - d_pairs.clear(); - d_match_trie.clear(); -} - -void SygusSamplerExt::setDynamicRewriter(DynamicRewriter* dr) -{ - d_drewrite = dr; -} - -Node SygusSamplerExt::registerTerm(Node n, bool forceKeep) -{ - Node eq_n = SygusSampler::registerTerm(n, forceKeep); - if (eq_n == n) - { - // this is a unique term - return n; - } - Node bn = n; - Node beq_n = eq_n; - if (d_use_sygus_type) - { - bn = d_tds->sygusToBuiltin(n); - beq_n = d_tds->sygusToBuiltin(eq_n); - } - Trace("sygus-synth-rr") << "sygusSampleExt : " << bn << "..." << beq_n - << std::endl; - // whether we will keep this pair - bool keep = true; - - // ----- check ordering redundancy - if (options::sygusRewSynthFilterOrder()) - { - bool nor = isOrdered(bn); - bool eqor = isOrdered(beq_n); - Trace("sygus-synth-rr-debug") << "Ordered? : " << nor << " " << eqor - << std::endl; - if (eqor || nor) - { - // if only one is ordered, then we require that the ordered one's - // variables cannot be a strict subset of the variables of the other. - if (!eqor) - { - if (containsFreeVariables(beq_n, bn, true)) - { - keep = false; - } - else - { - // if the previous value stored was unordered, but n is - // ordered, we prefer n. Thus, we force its addition to the - // sampler database. - SygusSampler::registerTerm(n, true); - } - } - else if (!nor) - { - keep = !containsFreeVariables(bn, beq_n, true); - } - } - else - { - keep = false; - } - if (!keep) - { - Trace("sygus-synth-rr") << "...redundant (unordered)" << std::endl; - } - } - - // ----- check rewriting redundancy - if (keep && d_drewrite != nullptr) - { - Trace("sygus-synth-rr-debug") << "Check equal rewrite pair..." << std::endl; - if (d_drewrite->areEqual(bn, beq_n)) - { - // must be unique according to the dynamic rewriter - Trace("sygus-synth-rr") << "...redundant (rewritable)" << std::endl; - keep = false; - } - } - - if (options::sygusRewSynthFilterMatch()) - { - // ----- check matchable - // check whether the pair is matchable with a previous one - d_curr_pair_rhs = beq_n; - Trace("sse-match") << "SSE check matches : " << bn << " [rhs = " << beq_n - << "]..." << std::endl; - if (!d_match_trie.getMatches(bn, &d_ssenm)) - { - keep = false; - Trace("sygus-synth-rr") << "...redundant (matchable)" << std::endl; - // regardless, would help to remember it - registerRelevantPair(n, eq_n); - } - } - - if (keep) - { - return eq_n; - } - if (Trace.isOn("sygus-rr-filter")) - { - Printer* p = Printer::getPrinter(options::outputLanguage()); - std::stringstream ss; - ss << "(redundant-rewrite "; - p->toStreamSygus(ss, n); - ss << " "; - p->toStreamSygus(ss, eq_n); - ss << ")"; - Trace("sygus-rr-filter") << ss.str() << std::endl; - } - return Node::null(); -} - -void SygusSamplerExt::registerRelevantPair(Node n, Node eq_n) -{ - Node bn = n; - Node beq_n = eq_n; - if (d_use_sygus_type) - { - bn = d_tds->sygusToBuiltin(n); - beq_n = d_tds->sygusToBuiltin(eq_n); - } - // ----- check rewriting redundancy - if (d_drewrite != nullptr) - { - Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl; - Assert(!d_drewrite->areEqual(bn, beq_n)); - d_drewrite->addRewrite(bn, beq_n); - } - if (options::sygusRewSynthFilterMatch()) - { - // add to match information - for (unsigned r = 0; r < 2; r++) - { - Node t = r == 0 ? bn : beq_n; - Node to = r == 0 ? beq_n : bn; - // insert in match trie if first time - if (d_pairs.find(t) == d_pairs.end()) - { - Trace("sse-match") << "SSE add term : " << t << std::endl; - d_match_trie.addTerm(t); - } - d_pairs[t].insert(to); - } - } -} - -bool SygusSamplerExt::notify(Node s, - Node n, - std::vector& vars, - std::vector& subs) -{ - Assert(!d_curr_pair_rhs.isNull()); - std::map >::iterator it = - d_pairs.find(n); - if (Trace.isOn("sse-match")) - { - Trace("sse-match") << " " << s << " matches " << n - << " under:" << std::endl; - for (unsigned i = 0, size = vars.size(); i < size; i++) - { - Trace("sse-match") << " " << vars[i] << " -> " << subs[i] << std::endl; - // TODO (#1923) ensure that we use an internal representation to - // ensure polymorphism is handled correctly - Assert(vars[i].getType().isComparableTo(subs[i].getType())); - } - } - Assert(it != d_pairs.end()); - for (const Node& nr : it->second) - { - Node nrs = - nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); - bool areEqual = (nrs == d_curr_pair_rhs); - if (!areEqual && d_drewrite != nullptr) - { - // if dynamic rewriter is available, consult it - areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs); - } - if (areEqual) - { - Trace("sse-match") << "*** Match, current pair: " << std::endl; - Trace("sse-match") << " (" << s << ", " << d_curr_pair_rhs << ")" - << std::endl; - Trace("sse-match") << "is an instance of previous pair:" << std::endl; - Trace("sse-match") << " (" << n << ", " << nr << ")" << std::endl; - return false; - } - } - return true; -} - -bool MatchTrie::getMatches(Node n, NotifyMatch* ntm) -{ - std::vector vars; - std::vector subs; - std::map smap; - - std::vector > visit; - std::vector visit_trie; - std::vector visit_var_index; - std::vector visit_bound_var; - - visit.push_back(std::vector{n}); - visit_trie.push_back(this); - visit_var_index.push_back(-1); - visit_bound_var.push_back(false); - while (!visit.empty()) - { - std::vector cvisit = visit.back(); - MatchTrie* curr = visit_trie.back(); - if (cvisit.empty()) - { - Assert(n - == curr->d_data.substitute( - vars.begin(), vars.end(), subs.begin(), subs.end())); - Trace("sse-match-debug") << "notify : " << curr->d_data << std::endl; - if (!ntm->notify(n, curr->d_data, vars, subs)) - { - return false; - } - visit.pop_back(); - visit_trie.pop_back(); - visit_var_index.pop_back(); - visit_bound_var.pop_back(); - } - else - { - Node cn = cvisit.back(); - Trace("sse-match-debug") - << "traverse : " << cn << " at depth " << visit.size() << std::endl; - unsigned index = visit.size() - 1; - int vindex = visit_var_index[index]; - if (vindex == -1) - { - if (!cn.isVar()) - { - Node op = cn.hasOperator() ? cn.getOperator() : cn; - unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0; - std::map::iterator itu = - curr->d_children[op].find(nchild); - if (itu != curr->d_children[op].end()) - { - // recurse on the operator or self - cvisit.pop_back(); - if (cn.hasOperator()) - { - for (const Node& cnc : cn) - { - cvisit.push_back(cnc); - } - } - Trace("sse-match-debug") << "recurse op : " << op << std::endl; - visit.push_back(cvisit); - visit_trie.push_back(&itu->second); - visit_var_index.push_back(-1); - visit_bound_var.push_back(false); - } - } - visit_var_index[index]++; - } - else - { - // clean up if we previously bound a variable - if (visit_bound_var[index]) - { - Assert(!vars.empty()); - smap.erase(vars.back()); - vars.pop_back(); - subs.pop_back(); - visit_bound_var[index] = false; - } - - if (vindex == static_cast(curr->d_vars.size())) - { - Trace("sse-match-debug") - << "finished checking " << curr->d_vars.size() - << " variables at depth " << visit.size() << std::endl; - // finished - visit.pop_back(); - visit_trie.pop_back(); - visit_var_index.pop_back(); - visit_bound_var.pop_back(); - } - else - { - Trace("sse-match-debug") << "check variable #" << vindex - << " at depth " << visit.size() << std::endl; - Assert(vindex < static_cast(curr->d_vars.size())); - // recurse on variable? - Node var = curr->d_vars[vindex]; - bool recurse = true; - // check if it is already bound - std::map::iterator its = smap.find(var); - if (its != smap.end()) - { - if (its->second != cn) - { - recurse = false; - } - } - else - { - vars.push_back(var); - subs.push_back(cn); - smap[var] = cn; - visit_bound_var[index] = true; - } - if (recurse) - { - Trace("sse-match-debug") << "recurse var : " << var << std::endl; - cvisit.pop_back(); - visit.push_back(cvisit); - visit_trie.push_back(&curr->d_children[var][0]); - visit_var_index.push_back(-1); - visit_bound_var.push_back(false); - } - visit_var_index[index]++; - } - } - } - } - return true; -} - -void MatchTrie::addTerm(Node n) -{ - std::vector visit; - visit.push_back(n); - MatchTrie* curr = this; - while (!visit.empty()) - { - Node cn = visit.back(); - visit.pop_back(); - if (cn.hasOperator()) - { - curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]); - for (const Node& cnc : cn) - { - visit.push_back(cnc); - } - } - else - { - if (cn.isVar() - && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn) - == curr->d_vars.end()) - { - curr->d_vars.push_back(cn); - } - curr = &(curr->d_children[cn][0]); - } - } - curr->d_data = n; -} - -void MatchTrie::clear() -{ - d_children.clear(); - d_vars.clear(); - d_data = Node::null(); -} - } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index 290a8b17d..0134b3a86 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -19,7 +19,6 @@ #include #include "theory/evaluator.h" -#include "theory/quantifiers/dynamic_rewrite.h" #include "theory/quantifiers/lazy_trie.h" #include "theory/quantifiers/sygus/term_database_sygus.h" #include "theory/quantifiers/term_enumeration.h" @@ -28,7 +27,6 @@ namespace CVC4 { namespace theory { namespace quantifiers { - /** SygusSampler * * This class can be used to test whether two expressions are equivalent @@ -124,7 +122,7 @@ class SygusSampler : public LazyTrieEvaluator */ int getDiffSamplePointIndex(Node a, Node b); - protected: + //--------------------------queries about terms /** is contiguous * * This returns whether n's free variables (terms occurring in the range of @@ -149,6 +147,7 @@ class SygusSampler : public LazyTrieEvaluator * occur in the range d_type_vars. */ bool containsFreeVariables(Node a, Node b, bool strict = false); + //--------------------------end queries about terms protected: /** sygus term database of d_qe */ @@ -286,167 +285,6 @@ class SygusSampler : public LazyTrieEvaluator void registerSygusType(TypeNode tn); }; -/** A virtual class for notifications regarding matches. */ -class NotifyMatch -{ - public: - virtual ~NotifyMatch() {} - - /** - * A notification that s is equal to n * { vars -> subs }. This function - * should return false if we do not wish to be notified of further matches. - */ - virtual bool notify(Node s, - Node n, - std::vector& vars, - std::vector& subs) = 0; -}; - -/** - * A trie (discrimination tree) storing a set of terms S, that can be used to - * query, for a given term t, all terms from S that are matchable with t. - */ -class MatchTrie -{ - public: - /** Get matches - * - * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this - * trie that is matchable with n where s = n * { vars -> subs } for some - * vars, subs. This function returns false if one of these calls to notify - * returns false. - */ - bool getMatches(Node n, NotifyMatch* ntm); - /** Adds node n to this trie */ - void addTerm(Node n); - /** Clear this trie */ - void clear(); - - private: - /** - * The children of this node in the trie. Terms t are indexed by a - * depth-first (right to left) traversal on its subterms, where the - * top-symbol of t is indexed by: - * - (operator, #children) if t has an operator, or - * - (t, 0) if t does not have an operator. - */ - std::map > d_children; - /** The set of variables in the domain of d_children */ - std::vector d_vars; - /** The data of this node in the trie */ - Node d_data; -}; - -/** Version of the above class with some additional features */ -class SygusSamplerExt : public SygusSampler -{ - public: - SygusSamplerExt(); - /** initialize */ - void initializeSygus(TermDbSygus* tds, - Node f, - unsigned nsamples, - bool useSygusType) override; - /** set dynamic rewriter - * - * This tells this class to use the dynamic rewriter object dr. This utility - * is used to query whether pairs of terms are already entailed to be - * equal based on previous rewrite rules. - */ - void setDynamicRewriter(DynamicRewriter* dr); - - /** register term n with this sampler database - * - * For each call to registerTerm( t, ... ) that returns s, we say that - * (t,s) and (s,t) are "relevant pairs". - * - * This returns either null, or a term ret with the same guarantees as - * SygusSampler::registerTerm with the additional guarantee - * that for all previous relevant pairs ( n', nret' ), - * we have that n = ret is not an instance of n' = ret' - * modulo symmetry of equality, nor is n = ret derivable from the set of - * all previous relevant pairs. The latter is determined by the d_drewrite - * utility. For example: - * [1] ( t+0, t ) and ( x+0, x ) - * will not both be relevant pairs of this function since t+0=t is - * an instance of x+0=x. - * [2] ( s, t ) and ( s+0, t+0 ) - * will not both be relevant pairs of this function since s+0=t+0 is - * derivable from s=t. - * These two criteria may be combined, for example: - * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s ) - * are relevant pairs, since t+0 is an instance of x+0 where - * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair - * above (t+s = s). - * - * If this function returns null, then n is equivalent to a previously - * registered term ret, and the equality ( n, ret ) is either an instance - * of a previous relevant pair ( n', ret' ), or n = ret is derivable - * from the set of all previous relevant pairs based on the - * d_drewrite utility, or is an instance of a previous pair - */ - Node registerTerm(Node n, bool forceKeep = false) override; - /** register relevant pair - * - * This should be called after registerTerm( n ) returns eq_n. - * This registers ( n, eq_n ) as a relevant pair with this class. - */ - void registerRelevantPair(Node n, Node eq_n); - - private: - /** pointer to the dynamic rewriter class */ - DynamicRewriter* d_drewrite; - - //----------------------------match filtering - /** - * Stores all relevant pairs returned by this sampler (see registerTerm). In - * detail, if (t,s) is a relevant pair, then t in d_pairs[s]. - */ - std::map > d_pairs; - /** Match trie storing all terms in the domain of d_pairs. */ - MatchTrie d_match_trie; - /** Notify class */ - class SygusSamplerExtNotifyMatch : public NotifyMatch - { - SygusSamplerExt& d_sse; - - public: - SygusSamplerExtNotifyMatch(SygusSamplerExt& sse) : d_sse(sse) {} - /** notify match */ - bool notify(Node n, - Node s, - std::vector& vars, - std::vector& subs) override - { - return d_sse.notify(n, s, vars, subs); - } - }; - /** Notify object used for reporting matches from d_match_trie */ - SygusSamplerExtNotifyMatch d_ssenm; - /** - * Stores the current right hand side of a pair we are considering. - * - * In more detail, in registerTerm, we are interested in whether a pair (s,t) - * is a relevant pair. We do this by: - * (1) Setting the node d_curr_pair_rhs to t, - * (2) Using d_match_trie, compute all terms s1...sn that match s. - * For each si, where s = si * sigma for some substitution sigma, we check - * whether t = ti * sigma for some previously relevant pair (si,ti), in - * which case (s,t) is an instance of (si,ti). - */ - Node d_curr_pair_rhs; - /** - * Called by the above class during d_match_trie.getMatches( s ), when we - * find that si = s * sigma, where si is a term that is stored in - * d_match_trie. - * - * This function returns false if ( s, d_curr_pair_rhs ) is an instance of - * previously relevant pair. - */ - bool notify(Node s, Node n, std::vector& vars, std::vector& subs); - //----------------------------end match filtering -}; - } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */