From: Andrew Reynolds Date: Tue, 27 Mar 2018 16:53:49 +0000 (-0500) Subject: Filter candidate rewrites based on matching (#1682) X-Git-Tag: cvc5-1.0.0~5201 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=9dcaaeba4880a8f145df00289ff1b092a7e3dd47;p=cvc5.git Filter candidate rewrites based on matching (#1682) --- diff --git a/src/theory/quantifiers/dynamic_rewrite.cpp b/src/theory/quantifiers/dynamic_rewrite.cpp index 3462a4d10..cb7379910 100644 --- a/src/theory/quantifiers/dynamic_rewrite.cpp +++ b/src/theory/quantifiers/dynamic_rewrite.cpp @@ -66,6 +66,20 @@ bool DynamicRewriter::addRewrite(Node a, Node b) return true; } +bool DynamicRewriter::areEqual(Node a, Node b) +{ + if (a == b) + { + return true; + } + // add to the equality engine + Node ai = toInternal(a); + Node bi = toInternal(b); + d_equalityEngine.addTerm(ai); + d_equalityEngine.addTerm(bi); + return d_equalityEngine.areEqual(ai, bi); +} + Node DynamicRewriter::toInternal(Node a) { std::map::iterator it = d_term_to_internal.find(a); diff --git a/src/theory/quantifiers/dynamic_rewrite.h b/src/theory/quantifiers/dynamic_rewrite.h index 2b5464151..388173829 100644 --- a/src/theory/quantifiers/dynamic_rewrite.h +++ b/src/theory/quantifiers/dynamic_rewrite.h @@ -63,6 +63,10 @@ class DynamicRewriter * a = b based on the previous equalities it has seen. */ bool addRewrite(Node a, Node b); + /** + * Check whether this class knows that the equality a = b holds. + */ + bool areEqual(Node a, Node b); private: /** pointer to the quantifiers engine */ diff --git a/src/theory/quantifiers/sygus_sampler.cpp b/src/theory/quantifiers/sygus_sampler.cpp index afbdc42e1..99494657f 100644 --- a/src/theory/quantifiers/sygus_sampler.cpp +++ b/src/theory/quantifiers/sygus_sampler.cpp @@ -678,6 +678,8 @@ void SygusSampler::registerSygusType(TypeNode tn) } } +SygusSamplerExt::SygusSamplerExt() : d_ssenm(*this) {} + void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe, Node f, unsigned nsamples, @@ -691,6 +693,8 @@ void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe, ss << f; d_drewrite = std::unique_ptr(new DynamicRewriter(ss.str(), qe)); + d_pairs.clear(); + d_match_trie.clear(); } Node SygusSamplerExt::registerTerm(Node n, bool forceKeep) @@ -700,6 +704,7 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep) << std::endl; if (eq_n == n) { + // this is a unique term return n; } Node bn = n; @@ -709,63 +714,268 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep) bn = d_tds->sygusToBuiltin(n); beq_n = d_tds->sygusToBuiltin(eq_n); } - // one of eq_n or n must be ordered - bool eqor = isOrdered(beq_n); - bool nor = isOrdered(bn); - Trace("sygus-synth-rr-debug") - << "Ordered? : " << nor << " " << eqor << std::endl; - bool isUnique = false; - if (eqor || nor) + // whether we will keep this pair + bool keep = true; + + // ----- check matchable + // check whether the pair is matchable with a previous one + d_curr_pair_rhs = beq_n; + Trace("sse-match") << "SSE check matches : " << n << " [rhs = " << eq_n + << "]..." << std::endl; + if (!d_match_trie.getMatches(bn, &d_ssenm)) { - isUnique = true; - // if only one is ordered, then the ordered one must contain the - // free variables of the other - if (!eqor) - { - isUnique = containsFreeVariables(bn, beq_n); - } - else if (!nor) - { - isUnique = containsFreeVariables(beq_n, bn); - } + keep = false; + Trace("sygus-synth-rr-debug") << "...redundant (matchable)" << std::endl; } - Trace("sygus-synth-rr-debug") << "AlphaEq unique: " << isUnique << std::endl; - bool rewRedundant = false; + + // ----- check rewriting redundancy if (d_drewrite != nullptr) { - Trace("sygus-synth-rr-debug") << "Add rewrite..." << std::endl; + Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl; if (!d_drewrite->addRewrite(bn, beq_n)) { - rewRedundant = isUnique; // must be unique according to the dynamic rewriter - isUnique = false; + keep = false; + Trace("sygus-synth-rr-debug") << "...redundant (rewritable)" << std::endl; } } - Trace("sygus-synth-rr-debug") << "Rewrite unique: " << isUnique << std::endl; - if (isUnique) + if (keep) { - // if the previous value stored was unordered, but this is - // ordered, we prefer this one. Thus, we force its addition to the - // sampler database. - if (!eqor) + // add to match information + for (unsigned r = 0; r < 2; r++) { - SygusSampler::registerTerm(n, true); + Node t = r == 0 ? bn : beq_n; + Node to = r == 0 ? beq_n : bn; + // insert in match trie if first time + if (d_pairs.find(t) == d_pairs.end()) + { + Trace("sse-match") << "SSE add term : " << t << std::endl; + d_match_trie.addTerm(t); + } + d_pairs[t].insert(to); } return eq_n; } else if (Trace.isOn("sygus-synth-rr")) { - Trace("sygus-synth-rr") << "Redundant rewrite : " << eq_n << " " << n; - if (rewRedundant) - { - Trace("sygus-synth-rr") << " (by rewriting)"; - } + Trace("sygus-synth-rr") << "Redundant pair : " << eq_n << " " << n; Trace("sygus-synth-rr") << std::endl; } return Node::null(); } +bool SygusSamplerExt::notify(Node s, + Node n, + std::vector& vars, + std::vector& subs) +{ + Assert(!d_curr_pair_rhs.isNull()); + std::map >::iterator it = + d_pairs.find(n); + if (Trace.isOn("sse-match")) + { + Trace("sse-match") << " " << s << " matches " << n + << " under:" << std::endl; + for (unsigned i = 0, size = vars.size(); i < size; i++) + { + Trace("sse-match") << " " << vars[i] << " -> " << subs[i] << std::endl; + } + } + Assert(it != d_pairs.end()); + for (const Node& nr : it->second) + { + Node nrs = + nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + bool areEqual = (nrs == d_curr_pair_rhs); + if (!areEqual && d_drewrite != nullptr) + { + // if dynamic rewriter is available, consult it + areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs); + } + if (areEqual) + { + Trace("sse-match") << "*** Match, current pair: " << std::endl; + Trace("sse-match") << " (" << s << ", " << d_curr_pair_rhs << ")" + << std::endl; + Trace("sse-match") << "is an instance of previous pair:" << std::endl; + Trace("sse-match") << " (" << n << ", " << nr << ")" << std::endl; + return false; + } + } + return true; +} + +bool MatchTrie::getMatches(Node n, NotifyMatch* ntm) +{ + std::vector vars; + std::vector subs; + std::map smap; + + std::vector > visit; + std::vector visit_trie; + std::vector visit_var_index; + std::vector visit_bound_var; + + visit.push_back(std::vector{n}); + visit_trie.push_back(this); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + while (!visit.empty()) + { + std::vector cvisit = visit.back(); + MatchTrie* curr = visit_trie.back(); + if (cvisit.empty()) + { + Assert(n + == curr->d_data.substitute( + vars.begin(), vars.end(), subs.begin(), subs.end())); + Trace("sse-match-debug") << "notify : " << curr->d_data << std::endl; + if (!ntm->notify(n, curr->d_data, vars, subs)) + { + return false; + } + visit.pop_back(); + visit_trie.pop_back(); + visit_var_index.pop_back(); + visit_bound_var.pop_back(); + } + else + { + Node cn = cvisit.back(); + Trace("sse-match-debug") + << "traverse : " << cn << " at depth " << visit.size() << std::endl; + unsigned index = visit.size() - 1; + int vindex = visit_var_index[index]; + if (vindex == -1) + { + if (!cn.isVar()) + { + Node op = cn.hasOperator() ? cn.getOperator() : cn; + unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0; + std::map::iterator itu = + curr->d_children[op].find(nchild); + if (itu != curr->d_children[op].end()) + { + // recurse on the operator or self + cvisit.pop_back(); + if (cn.hasOperator()) + { + for (const Node& cnc : cn) + { + cvisit.push_back(cnc); + } + } + Trace("sse-match-debug") << "recurse op : " << op << std::endl; + visit.push_back(cvisit); + visit_trie.push_back(&itu->second); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + } + } + visit_var_index[index]++; + } + else + { + // clean up if we previously bound a variable + if (visit_bound_var[index]) + { + Assert(!vars.empty()); + smap.erase(vars.back()); + vars.pop_back(); + subs.pop_back(); + } + + if (vindex == static_cast(curr->d_vars.size())) + { + Trace("sse-match-debug") + << "finished checking " << curr->d_vars.size() + << " variables at depth " << visit.size() << std::endl; + // finished + visit.pop_back(); + visit_trie.pop_back(); + visit_var_index.pop_back(); + visit_bound_var.pop_back(); + } + else + { + Trace("sse-match-debug") << "check variable #" << vindex + << " at depth " << visit.size() << std::endl; + Assert(vindex < static_cast(curr->d_vars.size())); + // recurse on variable? + Node var = curr->d_vars[vindex]; + bool recurse = true; + // check if it is already bound + std::map::iterator its = smap.find(var); + if (its != smap.end()) + { + if (its->second != cn) + { + recurse = false; + } + } + else + { + vars.push_back(var); + subs.push_back(cn); + smap[var] = cn; + visit_bound_var[index] = true; + } + if (recurse) + { + Trace("sse-match-debug") << "recurse var : " << var << std::endl; + cvisit.pop_back(); + visit.push_back(cvisit); + visit_trie.push_back(&curr->d_children[var][0]); + visit_var_index.push_back(-1); + visit_bound_var.push_back(false); + } + visit_var_index[index]++; + } + } + } + } + return true; +} + +void MatchTrie::addTerm(Node n) +{ + std::vector visit; + visit.push_back(n); + MatchTrie* curr = this; + while (!visit.empty()) + { + Node cn = visit.back(); + visit.pop_back(); + if (cn.hasOperator()) + { + curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]); + for (const Node& cnc : cn) + { + visit.push_back(cnc); + } + } + else + { + if (cn.isVar() + && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn) + == curr->d_vars.end()) + { + curr->d_vars.push_back(cn); + } + curr = &(curr->d_children[cn][0]); + } + } + curr->d_data = n; +} + +void MatchTrie::clear() +{ + d_children.clear(); + d_vars.clear(); + d_data = Node::null(); +} + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_sampler.h b/src/theory/quantifiers/sygus_sampler.h index 4bc10075d..fa0d670d2 100644 --- a/src/theory/quantifiers/sygus_sampler.h +++ b/src/theory/quantifiers/sygus_sampler.h @@ -340,42 +340,149 @@ class SygusSampler : public LazyTrieEvaluator void registerSygusType(TypeNode tn); }; +/** A virtual class for notifications regarding matches. */ +class NotifyMatch +{ + public: + /** + * A notification that s is equal to n * { vars -> subs }. This function + * should return false if we do not wish to be notified of further matches. + */ + virtual bool notify(Node s, + Node n, + std::vector& vars, + std::vector& subs) = 0; +}; + +/** + * A trie (discrimination tree) storing a set of terms S, that can be used to + * query, for a given term t, all terms from S that are matchable with t. + */ +class MatchTrie +{ + public: + /** Get matches + * + * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this + * trie that is matchable with n where s = n * { vars -> subs } for some + * vars, subs. This function returns false if one of these calls to notify + * returns false. + */ + bool getMatches(Node n, NotifyMatch* ntm); + /** Adds node n to this trie */ + void addTerm(Node n); + /** Clear this trie */ + void clear(); + + private: + /** + * The children of this node in the trie. Terms t are indexed by a + * depth-first (right to left) traversal on its subterms, where the + * top-symbol of t is indexed by: + * - (operator, #children) if t has an operator, or + * - (t, 0) if t does not have an operator. + */ + std::map > d_children; + /** The set of variables in the domain of d_children */ + std::vector d_vars; + /** The data of this node in the trie */ + Node d_data; +}; + /** Version of the above class with some additional features */ class SygusSamplerExt : public SygusSampler { public: + SygusSamplerExt(); /** initialize extended */ void initializeSygusExt(QuantifiersEngine* qe, Node f, unsigned nsamples, bool useSygusType); /** register term n with this sampler database + * + * For each call to registerTerm( t, ... ) that returns s, we say that + * (t,s) and (s,t) are "relevant pairs". * * This returns either null, or a term ret with the same guarantees as * SygusSampler::registerTerm with the additional guarantee - * that for all ret' returned by a previous call to registerTerm( n' ), - * we have that n = ret is not alpha-equivalent to n' = ret' + * that for all previous relevant pairs ( n', nret' ), + * we have that n = ret is not an instance of n' = ret' * modulo symmetry of equality, nor is n = ret derivable from the set of - * all previous input/output pairs based on the d_drewrite utility. - * For example, - * (t+0), t and (s+0), s - * will not both be input/output pairs of this function since t+0=t is - * alpha-equivalent to s+0=s. - * s, t and s+0, t+0 - * will not both be input/output pairs of this function since s+0=t+0 is + * all previous relevant pairs. The latter is determined by the d_drewrite + * utility. For example: + * [1] ( t+0, t ) and ( x+0, x ) + * will not both be relevant pairs of this function since t+0=t is + * an instance of x+0=x. + * [2] ( s, t ) and ( s+0, t+0 ) + * will not both be relevant pairs of this function since s+0=t+0 is * derivable from s=t. + * These two criteria may be combined, for example: + * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s ) + * are relevant pairs, since t+0 is an instance of x+0 where + * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair + * above (t+s = s). * * If this function returns null, then n is equivalent to a previously - * registered term ret, and the equality n = ret is either alpha-equivalent - * to a previous input/output pair n' = ret', or n = ret is derivable - * from the set of all previous input/output pairs based on the - * d_drewrite utility. + * registered term ret, and the equality ( n, ret ) is either an instance + * of a previous relevant pair ( n', ret' ), or n = ret is derivable + * from the set of all previous relevant pairs based on the + * d_drewrite utility, or is an instance of a previous pair */ Node registerTerm(Node n, bool forceKeep = false) override; private: /** dynamic rewriter class */ std::unique_ptr d_drewrite; + + //----------------------------match filtering + /** + * Stores all relevant pairs returned by this sampler (see registerTerm). In + * detail, if (t,s) is a relevant pair, then t in d_pairs[s]. + */ + std::map > d_pairs; + /** Match trie storing all terms in the domain of d_pairs. */ + MatchTrie d_match_trie; + /** Notify class */ + class SygusSamplerExtNotifyMatch : public NotifyMatch + { + SygusSamplerExt& d_sse; + + public: + SygusSamplerExtNotifyMatch(SygusSamplerExt& sse) : d_sse(sse) {} + /** notify match */ + bool notify(Node n, + Node s, + std::vector& vars, + std::vector& subs) override + { + return d_sse.notify(n, s, vars, subs); + } + }; + /** Notify object used for reporting matches from d_match_trie */ + SygusSamplerExtNotifyMatch d_ssenm; + /** + * Stores the current right hand side of a pair we are considering. + * + * In more detail, in registerTerm, we are interested in whether a pair (s,t) + * is a relevant pair. We do this by: + * (1) Setting the node d_curr_pair_rhs to t, + * (2) Using d_match_trie, compute all terms s1...sn that match s. + * For each si, where s = si * sigma for some substitution sigma, we check + * whether t = ti * sigma for some previously relevant pair (si,ti), in + * which case (s,t) is an instance of (si,ti). + */ + Node d_curr_pair_rhs; + /** + * Called by the above class during d_match_trie.getMatches( s ), when we + * find that si = s * sigma, where si is a term that is stored in + * d_match_trie. + * + * This function returns false if ( s, d_curr_pair_rhs ) is an instance of + * previously relevant pair. + */ + bool notify(Node s, Node n, std::vector& vars, std::vector& subs); + //----------------------------end match filtering }; } /* CVC4::theory::quantifiers namespace */