Reorganize candidate rewrite rule filtering (#2116)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 4 Jul 2018 13:31:14 +0000 (14:31 +0100)
committerAina Niemetz <aina.niemetz@gmail.com>
Wed, 4 Jul 2018 13:31:14 +0000 (06:31 -0700)
src/Makefile.am
src/theory/quantifiers/candidate_rewrite_database.cpp
src/theory/quantifiers/candidate_rewrite_database.h
src/theory/quantifiers/candidate_rewrite_filter.cpp [new file with mode: 0644]
src/theory/quantifiers/candidate_rewrite_filter.h [new file with mode: 0644]
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index b5da564cf1fdba5984fda89810185b2a4e7b1b71..917fc6ef37c41ec1a09d3201d16c417e24cc7acc 100644 (file)
@@ -393,6 +393,8 @@ libcvc4_la_SOURCES = \
        theory/quantifiers/bv_inverter.h \
        theory/quantifiers/candidate_rewrite_database.cpp \
        theory/quantifiers/candidate_rewrite_database.h \
+       theory/quantifiers/candidate_rewrite_filter.cpp \
+       theory/quantifiers/candidate_rewrite_filter.h \
        theory/quantifiers/cegqi/ceg_instantiator.cpp \
        theory/quantifiers/cegqi/ceg_instantiator.h \
        theory/quantifiers/cegqi/ceg_arith_instantiator.cpp \
index 9bbb88699f1c8f4207c7222ab250be0cca2141bf..a5a35f89d8863b533a36999ca7077749424c1d24 100644 (file)
@@ -32,25 +32,12 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-// the number of d_drewrite objects we have allocated (to avoid name conflicts)
-static unsigned drewrite_counter = 0;
-
 CandidateRewriteDatabase::CandidateRewriteDatabase()
     : d_qe(nullptr),
       d_tds(nullptr),
       d_ext_rewrite(nullptr),
       d_using_sygus(false)
 {
-  if (options::sygusRewSynthFilterCong())
-  {
-    // initialize the dynamic rewriter
-    std::stringstream ss;
-    ss << "_dyn_rewriter_" << drewrite_counter;
-    drewrite_counter++;
-    d_drewrite = std::unique_ptr<DynamicRewriter>(
-        new DynamicRewriter(ss.str(), &d_fake_context));
-    d_sampler.setDynamicRewriter(d_drewrite.get());
-  }
 }
 void CandidateRewriteDatabase::initialize(ExtendedRewriter* er,
                                           TypeNode tn,
@@ -65,6 +52,7 @@ void CandidateRewriteDatabase::initialize(ExtendedRewriter* er,
   d_tds = nullptr;
   d_ext_rewrite = er;
   d_sampler.initialize(tn, vars, nsamples, unique_type_ids);
+  d_crewrite_filter.initialize(&d_sampler, nullptr, false);
 }
 
 void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe,
@@ -81,6 +69,7 @@ void CandidateRewriteDatabase::initializeSygus(QuantifiersEngine* qe,
   d_tds = d_qe->getTermDatabaseSygus();
   d_ext_rewrite = d_tds->getExtRewriter();
   d_sampler.initializeSygus(d_tds, f, nsamples, useSygusType);
+  d_crewrite_filter.initialize(&d_sampler, d_tds, true);
 }
 
 bool CandidateRewriteDatabase::addTerm(Node sol,
@@ -93,9 +82,8 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
   if (eq_sol != sol)
   {
     is_unique_term = false;
-    // if eq_sol is null, then we have an uninteresting candidate rewrite,
-    // e.g. one that is alpha-equivalent to another.
-    if (!eq_sol.isNull())
+    // should we filter the pair?
+    if (!d_crewrite_filter.filterPair(sol, eq_sol))
     {
       // get the actual term
       Node solb = sol;
@@ -215,7 +203,7 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
       if (!is_unique_term)
       {
         // register this as a relevant pair (helps filtering)
-        d_sampler.registerRelevantPair(sol, eq_sol);
+        d_crewrite_filter.registerRelevantPair(sol, eq_sol);
         // The analog of terms sol and eq_sol are equivalent under
         // sample points but do not rewrite to the same term. Hence,
         // this indicates a candidate rewrite.
index a2a6c5745c2ab7d00a2af384da359f1944670f10..7f51043e23f724bbfcb0379389cfeeacb88f436b 100644 (file)
@@ -21,6 +21,7 @@
 #include <memory>
 #include <unordered_set>
 #include <vector>
+#include "theory/quantifiers/candidate_rewrite_filter.h"
 #include "theory/quantifiers/sygus_sampler.h"
 
 namespace CVC4 {
@@ -116,11 +117,9 @@ class CandidateRewriteDatabase
    * This is used for the sygusRewSynth() option to synthesize new candidate
    * rewrite rules.
    */
-  SygusSamplerExt d_sampler;
-  /** a (dummy) user context, used for d_drewrite */
-  context::UserContext d_fake_context;
-  /** dynamic rewriter class */
-  std::unique_ptr<DynamicRewriter> d_drewrite;
+  SygusSampler d_sampler;
+  /** candidate rewrite filter */
+  CandidateRewriteFilter d_crewrite_filter;
   /**
    * Cache of skolems for each free variable that appears in a synthesis check
    * (for --sygus-rr-synth-check).
diff --git a/src/theory/quantifiers/candidate_rewrite_filter.cpp b/src/theory/quantifiers/candidate_rewrite_filter.cpp
new file mode 100644 (file)
index 0000000..68a3abe
--- /dev/null
@@ -0,0 +1,413 @@
+/*********************                                                        */
+/*! \file candidate_rewrite_filter.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implements techniques for candidate rewrite rule filtering, which
+ ** filters the output of --sygus-rr-synth. The classes in this file implement
+ ** filtering based on congruence, variable ordering, and matching.
+ **/
+
+#include "theory/quantifiers/candidate_rewrite_filter.h"
+
+#include "options/base_options.h"
+#include "options/quantifiers_options.h"
+#include "printer/printer.h"
+
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace quantifiers {
+
+bool MatchTrie::getMatches(Node n, NotifyMatch* ntm)
+{
+  std::vector<Node> vars;
+  std::vector<Node> subs;
+  std::map<Node, Node> smap;
+
+  std::vector<std::vector<Node> > visit;
+  std::vector<MatchTrie*> visit_trie;
+  std::vector<int> visit_var_index;
+  std::vector<bool> visit_bound_var;
+
+  visit.push_back(std::vector<Node>{n});
+  visit_trie.push_back(this);
+  visit_var_index.push_back(-1);
+  visit_bound_var.push_back(false);
+  while (!visit.empty())
+  {
+    std::vector<Node> cvisit = visit.back();
+    MatchTrie* curr = visit_trie.back();
+    if (cvisit.empty())
+    {
+      Assert(n
+             == curr->d_data.substitute(
+                    vars.begin(), vars.end(), subs.begin(), subs.end()));
+      Trace("crf-match-debug") << "notify : " << curr->d_data << std::endl;
+      if (!ntm->notify(n, curr->d_data, vars, subs))
+      {
+        return false;
+      }
+      visit.pop_back();
+      visit_trie.pop_back();
+      visit_var_index.pop_back();
+      visit_bound_var.pop_back();
+    }
+    else
+    {
+      Node cn = cvisit.back();
+      Trace("crf-match-debug") << "traverse : " << cn << " at depth "
+                               << visit.size() << std::endl;
+      unsigned index = visit.size() - 1;
+      int vindex = visit_var_index[index];
+      if (vindex == -1)
+      {
+        if (!cn.isVar())
+        {
+          Node op = cn.hasOperator() ? cn.getOperator() : cn;
+          unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0;
+          std::map<unsigned, MatchTrie>::iterator itu =
+              curr->d_children[op].find(nchild);
+          if (itu != curr->d_children[op].end())
+          {
+            // recurse on the operator or self
+            cvisit.pop_back();
+            if (cn.hasOperator())
+            {
+              for (const Node& cnc : cn)
+              {
+                cvisit.push_back(cnc);
+              }
+            }
+            Trace("crf-match-debug") << "recurse op : " << op << std::endl;
+            visit.push_back(cvisit);
+            visit_trie.push_back(&itu->second);
+            visit_var_index.push_back(-1);
+            visit_bound_var.push_back(false);
+          }
+        }
+        visit_var_index[index]++;
+      }
+      else
+      {
+        // clean up if we previously bound a variable
+        if (visit_bound_var[index])
+        {
+          Assert(!vars.empty());
+          smap.erase(vars.back());
+          vars.pop_back();
+          subs.pop_back();
+          visit_bound_var[index] = false;
+        }
+
+        if (vindex == static_cast<int>(curr->d_vars.size()))
+        {
+          Trace("crf-match-debug")
+              << "finished checking " << curr->d_vars.size()
+              << " variables at depth " << visit.size() << std::endl;
+          // finished
+          visit.pop_back();
+          visit_trie.pop_back();
+          visit_var_index.pop_back();
+          visit_bound_var.pop_back();
+        }
+        else
+        {
+          Trace("crf-match-debug") << "check variable #" << vindex
+                                   << " at depth " << visit.size() << std::endl;
+          Assert(vindex < static_cast<int>(curr->d_vars.size()));
+          // recurse on variable?
+          Node var = curr->d_vars[vindex];
+          bool recurse = true;
+          // check if it is already bound
+          std::map<Node, Node>::iterator its = smap.find(var);
+          if (its != smap.end())
+          {
+            if (its->second != cn)
+            {
+              recurse = false;
+            }
+          }
+          else
+          {
+            vars.push_back(var);
+            subs.push_back(cn);
+            smap[var] = cn;
+            visit_bound_var[index] = true;
+          }
+          if (recurse)
+          {
+            Trace("crf-match-debug") << "recurse var : " << var << std::endl;
+            cvisit.pop_back();
+            visit.push_back(cvisit);
+            visit_trie.push_back(&curr->d_children[var][0]);
+            visit_var_index.push_back(-1);
+            visit_bound_var.push_back(false);
+          }
+          visit_var_index[index]++;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+void MatchTrie::addTerm(Node n)
+{
+  std::vector<Node> visit;
+  visit.push_back(n);
+  MatchTrie* curr = this;
+  while (!visit.empty())
+  {
+    Node cn = visit.back();
+    visit.pop_back();
+    if (cn.hasOperator())
+    {
+      curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]);
+      for (const Node& cnc : cn)
+      {
+        visit.push_back(cnc);
+      }
+    }
+    else
+    {
+      if (cn.isVar()
+          && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
+                 == curr->d_vars.end())
+      {
+        curr->d_vars.push_back(cn);
+      }
+      curr = &(curr->d_children[cn][0]);
+    }
+  }
+  curr->d_data = n;
+}
+
+void MatchTrie::clear()
+{
+  d_children.clear();
+  d_vars.clear();
+  d_data = Node::null();
+}
+
+// the number of d_drewrite objects we have allocated (to avoid name conflicts)
+static unsigned drewrite_counter = 0;
+
+CandidateRewriteFilter::CandidateRewriteFilter()
+    : d_ss(nullptr),
+      d_tds(nullptr),
+      d_use_sygus_type(false),
+      d_drewrite(nullptr),
+      d_ssenm(*this)
+{
+}
+
+void CandidateRewriteFilter::initialize(SygusSampler* ss,
+                                        TermDbSygus* tds,
+                                        bool useSygusType)
+{
+  d_ss = ss;
+  d_use_sygus_type = false;
+  d_tds = tds;
+  // initialize members of this class
+  d_match_trie.clear();
+  d_pairs.clear();
+  if (options::sygusRewSynthFilterCong())
+  {
+    // initialize the dynamic rewriter
+    std::stringstream ss;
+    ss << "_dyn_rewriter_" << drewrite_counter;
+    drewrite_counter++;
+    d_drewrite = std::unique_ptr<DynamicRewriter>(
+        new DynamicRewriter(ss.str(), &d_fake_context));
+  }
+}
+
+bool CandidateRewriteFilter::filterPair(Node n, Node eq_n)
+{
+  Node bn = n;
+  Node beq_n = eq_n;
+  if (d_use_sygus_type)
+  {
+    bn = d_tds->sygusToBuiltin(n);
+    beq_n = d_tds->sygusToBuiltin(eq_n);
+  }
+  Trace("cr-filter") << "crewriteFilter : " << bn << "..." << beq_n
+                     << std::endl;
+  // whether we will keep this pair
+  bool keep = true;
+
+  // ----- check ordering redundancy
+  if (options::sygusRewSynthFilterOrder())
+  {
+    bool nor = d_ss->isOrdered(bn);
+    bool eqor = d_ss->isOrdered(beq_n);
+    Trace("cr-filter-debug") << "Ordered? : " << nor << " " << eqor
+                             << std::endl;
+    if (eqor || nor)
+    {
+      // if only one is ordered, then we require that the ordered one's
+      // variables cannot be a strict subset of the variables of the other.
+      if (!eqor)
+      {
+        if (d_ss->containsFreeVariables(beq_n, bn, true))
+        {
+          keep = false;
+        }
+        else
+        {
+          // if the previous value stored was unordered, but n is
+          // ordered, we prefer n. Thus, we force its addition to the
+          // sampler database.
+          d_ss->registerTerm(n, true);
+        }
+      }
+      else if (!nor)
+      {
+        keep = !d_ss->containsFreeVariables(bn, beq_n, true);
+      }
+    }
+    else
+    {
+      keep = false;
+    }
+    if (!keep)
+    {
+      Trace("cr-filter") << "...redundant (unordered)" << std::endl;
+    }
+  }
+
+  // ----- check rewriting redundancy
+  if (keep && d_drewrite != nullptr)
+  {
+    Trace("cr-filter-debug") << "Check equal rewrite pair..." << std::endl;
+    if (d_drewrite->areEqual(bn, beq_n))
+    {
+      // must be unique according to the dynamic rewriter
+      Trace("cr-filter") << "...redundant (rewritable)" << std::endl;
+      keep = false;
+    }
+  }
+
+  if (options::sygusRewSynthFilterMatch())
+  {
+    // ----- check matchable
+    // check whether the pair is matchable with a previous one
+    d_curr_pair_rhs = beq_n;
+    Trace("crf-match") << "CRF check matches : " << bn << " [rhs = " << beq_n
+                       << "]..." << std::endl;
+    if (!d_match_trie.getMatches(bn, &d_ssenm))
+    {
+      keep = false;
+      Trace("cr-filter") << "...redundant (matchable)" << std::endl;
+      // regardless, would help to remember it
+      registerRelevantPair(n, eq_n);
+    }
+  }
+
+  if (keep)
+  {
+    return false;
+  }
+  if (Trace.isOn("sygus-rr-filter"))
+  {
+    Printer* p = Printer::getPrinter(options::outputLanguage());
+    std::stringstream ss;
+    ss << "(redundant-rewrite ";
+    p->toStreamSygus(ss, n);
+    ss << " ";
+    p->toStreamSygus(ss, eq_n);
+    ss << ")";
+    Trace("sygus-rr-filter") << ss.str() << std::endl;
+  }
+  return true;
+}
+
+void CandidateRewriteFilter::registerRelevantPair(Node n, Node eq_n)
+{
+  Node bn = n;
+  Node beq_n = eq_n;
+  if (d_use_sygus_type)
+  {
+    bn = d_tds->sygusToBuiltin(n);
+    beq_n = d_tds->sygusToBuiltin(eq_n);
+  }
+  // ----- check rewriting redundancy
+  if (d_drewrite != nullptr)
+  {
+    Trace("cr-filter-debug") << "Add rewrite pair..." << std::endl;
+    Assert(!d_drewrite->areEqual(bn, beq_n));
+    d_drewrite->addRewrite(bn, beq_n);
+  }
+  if (options::sygusRewSynthFilterMatch())
+  {
+    // add to match information
+    for (unsigned r = 0; r < 2; r++)
+    {
+      Node t = r == 0 ? bn : beq_n;
+      Node to = r == 0 ? beq_n : bn;
+      // insert in match trie if first time
+      if (d_pairs.find(t) == d_pairs.end())
+      {
+        Trace("crf-match") << "CRF add term : " << t << std::endl;
+        d_match_trie.addTerm(t);
+      }
+      d_pairs[t].insert(to);
+    }
+  }
+}
+
+bool CandidateRewriteFilter::notify(Node s,
+                                    Node n,
+                                    std::vector<Node>& vars,
+                                    std::vector<Node>& subs)
+{
+  Assert(!d_curr_pair_rhs.isNull());
+  std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
+      d_pairs.find(n);
+  if (Trace.isOn("crf-match"))
+  {
+    Trace("crf-match") << "  " << s << " matches " << n
+                       << " under:" << std::endl;
+    for (unsigned i = 0, size = vars.size(); i < size; i++)
+    {
+      Trace("crf-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
+      // TODO (#1923) ensure that we use an internal representation to
+      // ensure polymorphism is handled correctly
+      Assert(vars[i].getType().isComparableTo(subs[i].getType()));
+    }
+  }
+  Assert(it != d_pairs.end());
+  for (const Node& nr : it->second)
+  {
+    Node nrs =
+        nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+    bool areEqual = (nrs == d_curr_pair_rhs);
+    if (!areEqual && d_drewrite != nullptr)
+    {
+      // if dynamic rewriter is available, consult it
+      areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
+    }
+    if (areEqual)
+    {
+      Trace("crf-match") << "*** Match, current pair: " << std::endl;
+      Trace("crf-match") << "  (" << s << ", " << d_curr_pair_rhs << ")"
+                         << std::endl;
+      Trace("crf-match") << "is an instance of previous pair:" << std::endl;
+      Trace("crf-match") << "  (" << n << ", " << nr << ")" << std::endl;
+      return false;
+    }
+  }
+  return true;
+}
+
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/quantifiers/candidate_rewrite_filter.h b/src/theory/quantifiers/candidate_rewrite_filter.h
new file mode 100644 (file)
index 0000000..9a09680
--- /dev/null
@@ -0,0 +1,218 @@
+/*********************                                                        */
+/*! \file candidate_rewrite_filter.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief Implements techniques for candidate rewrite rule filtering, which
+ ** filters the output of --sygus-rr-synth. The classes in this file implement
+ ** filtering based on congruence, variable ordering, and matching.
+ **/
+
+#include "cvc4_private.h"
+
+#ifndef __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H
+#define __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H
+
+#include <map>
+#include "theory/quantifiers/dynamic_rewrite.h"
+#include "theory/quantifiers/sygus/term_database_sygus.h"
+#include "theory/quantifiers/sygus_sampler.h"
+
+namespace CVC4 {
+namespace theory {
+namespace quantifiers {
+
+/** A virtual class for notifications regarding matches. */
+class NotifyMatch
+{
+ public:
+  virtual ~NotifyMatch() {}
+  /**
+   * A notification that s is equal to n * { vars -> subs }. This function
+   * should return false if we do not wish to be notified of further matches.
+   */
+  virtual bool notify(Node s,
+                      Node n,
+                      std::vector<Node>& vars,
+                      std::vector<Node>& subs) = 0;
+};
+
+/**
+ * A trie (discrimination tree) storing a set of terms S, that can be used to
+ * query, for a given term t, all terms s from S that are matchable with t,
+ * that is s*sigma = t for some substitution sigma.
+ */
+class MatchTrie
+{
+ public:
+  /** Get matches
+   *
+   * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this
+   * trie that is matchable with n where s = n * { vars -> subs } for some
+   * vars, subs. This function returns false if one of these calls to notify
+   * returns false.
+   */
+  bool getMatches(Node n, NotifyMatch* ntm);
+  /** Adds node n to this trie */
+  void addTerm(Node n);
+  /** Clear this trie */
+  void clear();
+
+ private:
+  /**
+   * The children of this node in the trie. Terms t are indexed by a
+   * depth-first (right to left) traversal on its subterms, where the
+   * top-symbol of t is indexed by:
+   * - (operator, #children) if t has an operator, or
+   * - (t, 0) if t does not have an operator.
+   */
+  std::map<Node, std::map<unsigned, MatchTrie> > d_children;
+  /** The set of variables in the domain of d_children */
+  std::vector<Node> d_vars;
+  /** The data of this node in the trie */
+  Node d_data;
+};
+
+/** candidate rewrite filter
+ *
+ * This class is responsible for various filtering techniques for candidate
+ * rewrite rules, including:
+ * (1) filtering based on variable ordering,
+ * (2) filtering based on congruence,
+ * (3) filtering based on matching.
+ * For details, see Reynolds et al "Rewrites for SMT Solvers using Syntax-Guided
+ * Enumeration", SMT 2018.
+ *
+ * In the following, we assume that the registerRelevantPair method of this
+ * class been called for some pairs of terms. For each such call to
+ * registerRelevantPair( t, s ), we say that (t,s) and (s,t) are "relevant
+ * pairs". A relevant pair ( t, s ) typically corresponds to a (candidate)
+ * rewrite t = s.
+ */
+class CandidateRewriteFilter
+{
+ public:
+  CandidateRewriteFilter();
+
+  /** initialize
+   *
+   * Initializes this class, ss is the sygus sampler that this class is
+   * filtering rewrite rule pairs for, and tds is a point to the sygus term
+   * database utility class.
+   *
+   * If useSygusType is false, this means that the terms in filterPair and
+   * registerRelevantPair calls should be interpreted as themselves. Otherwise,
+   * if useSygusType is true, these terms should be interpreted as their
+   * analog according to the deep embedding.
+   */
+  void initialize(SygusSampler* ss, TermDbSygus* tds, bool useSygusType);
+  /** filter pair
+   *
+   * This method returns true if the pair (n, eq_n) should be filtered. If it
+   * is not filtered, then the caller may choose to call
+   * registerRelevantPair(n, eq_n) below, although it may not, say if it finds
+   * another reason to discard the pair.
+   *
+   * If this method returns false, then for all previous relevant pairs
+   * ( a, eq_a ), we have that n = eq_n is not an instance of a = eq_a
+   * modulo symmetry of equality, nor is n = eq_n derivable from the set of
+   * all previous relevant pairs. The latter is determined by the d_drewrite
+   * utility. For example:
+   * [1]  ( t+0, t ) and ( x+0, x )
+   * will not both be relevant pairs of this function since t+0=t is
+   * an instance of x+0=x.
+   * [2]  ( s, t ) and ( s+0, t+0 )
+   * will not both be relevant pairs of this function since s+0=t+0 is
+   * derivable from s=t.
+   * These two criteria may be combined, for example:
+   * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s )
+   * are relevant pairs, since t+0 is an instance of x+0 where
+   * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair
+   * above (t+s = s).
+   */
+  bool filterPair(Node n, Node eq_n);
+  /** register relevant pair
+   *
+   * This should be called after filterPair( n, eq_n ) returns false.
+   * This registers ( n, eq_n ) as a relevant pair with this class.
+   */
+  void registerRelevantPair(Node n, Node eq_n);
+
+ private:
+  /** pointer to the sygus sampler that this class is filtering rewrites for */
+  SygusSampler* d_ss;
+  /** pointer to the sygus term database, used if d_use_sygus_type is true */
+  TermDbSygus* d_tds;
+  /** whether we are registering sygus terms with this class */
+  bool d_use_sygus_type;
+
+  //----------------------------congruence filtering
+  /** a (dummy) user context, used for d_drewrite */
+  context::UserContext d_fake_context;
+  /** dynamic rewriter class */
+  std::unique_ptr<DynamicRewriter> d_drewrite;
+  //----------------------------end congruence filtering
+
+  //----------------------------match filtering
+  /**
+   * Stores all relevant pairs returned by this sampler (see registerTerm). In
+   * detail, if (t,s) is a relevant pair, then t in d_pairs[s].
+   */
+  std::map<Node, std::unordered_set<Node, NodeHashFunction> > d_pairs;
+  /** Match trie storing all terms in the domain of d_pairs. */
+  MatchTrie d_match_trie;
+  /** Notify class */
+  class CandidateRewriteFilterNotifyMatch : public NotifyMatch
+  {
+    CandidateRewriteFilter& d_sse;
+
+   public:
+    CandidateRewriteFilterNotifyMatch(CandidateRewriteFilter& sse) : d_sse(sse)
+    {
+    }
+    /** notify match */
+    bool notify(Node n,
+                Node s,
+                std::vector<Node>& vars,
+                std::vector<Node>& subs) override
+    {
+      return d_sse.notify(n, s, vars, subs);
+    }
+  };
+  /** Notify object used for reporting matches from d_match_trie */
+  CandidateRewriteFilterNotifyMatch d_ssenm;
+  /**
+   * Stores the current right hand side of a pair we are considering.
+   *
+   * In more detail, in registerTerm, we are interested in whether a pair (s,t)
+   * is a relevant pair. We do this by:
+   * (1) Setting the node d_curr_pair_rhs to t,
+   * (2) Using d_match_trie, compute all terms s1...sn that match s.
+   * For each si, where s = si * sigma for some substitution sigma, we check
+   * whether t = ti * sigma for some previously relevant pair (si,ti), in
+   * which case (s,t) is an instance of (si,ti).
+   */
+  Node d_curr_pair_rhs;
+  /**
+   * Called by the above class during d_match_trie.getMatches( s ), when we
+   * find that si = s * sigma, where si is a term that is stored in
+   * d_match_trie.
+   *
+   * This function returns false if ( s, d_curr_pair_rhs ) is an instance of
+   * previously relevant pair.
+   */
+  bool notify(Node s, Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+  //----------------------------end match filtering
+};
+
+}  // namespace quantifiers
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* __CVC4__THEORY__QUANTIFIERS__CANDIDATE_REWRITE_FILTER_H */
index e07f735407439f58d489de11670c4b086f9e860e..ebd10c5852c07af51483533532b9133e097e7935 100644 (file)
@@ -746,378 +746,6 @@ void SygusSampler::registerSygusType(TypeNode tn)
   }
 }
 
-SygusSamplerExt::SygusSamplerExt() : d_drewrite(nullptr), d_ssenm(*this) {}
-
-void SygusSamplerExt::initializeSygus(TermDbSygus* tds,
-                                      Node f,
-                                      unsigned nsamples,
-                                      bool useSygusType)
-{
-  SygusSampler::initializeSygus(tds, f, nsamples, useSygusType);
-  d_pairs.clear();
-  d_match_trie.clear();
-}
-
-void SygusSamplerExt::setDynamicRewriter(DynamicRewriter* dr)
-{
-  d_drewrite = dr;
-}
-
-Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
-{
-  Node eq_n = SygusSampler::registerTerm(n, forceKeep);
-  if (eq_n == n)
-  {
-    // this is a unique term
-    return n;
-  }
-  Node bn = n;
-  Node beq_n = eq_n;
-  if (d_use_sygus_type)
-  {
-    bn = d_tds->sygusToBuiltin(n);
-    beq_n = d_tds->sygusToBuiltin(eq_n);
-  }
-  Trace("sygus-synth-rr") << "sygusSampleExt : " << bn << "..." << beq_n
-                          << std::endl;
-  // whether we will keep this pair
-  bool keep = true;
-
-  // ----- check ordering redundancy
-  if (options::sygusRewSynthFilterOrder())
-  {
-    bool nor = isOrdered(bn);
-    bool eqor = isOrdered(beq_n);
-    Trace("sygus-synth-rr-debug") << "Ordered? : " << nor << " " << eqor
-                                  << std::endl;
-    if (eqor || nor)
-    {
-      // if only one is ordered, then we require that the ordered one's
-      // variables cannot be a strict subset of the variables of the other.
-      if (!eqor)
-      {
-        if (containsFreeVariables(beq_n, bn, true))
-        {
-          keep = false;
-        }
-        else
-        {
-          // if the previous value stored was unordered, but n is
-          // ordered, we prefer n. Thus, we force its addition to the
-          // sampler database.
-          SygusSampler::registerTerm(n, true);
-        }
-      }
-      else if (!nor)
-      {
-        keep = !containsFreeVariables(bn, beq_n, true);
-      }
-    }
-    else
-    {
-      keep = false;
-    }
-    if (!keep)
-    {
-      Trace("sygus-synth-rr") << "...redundant (unordered)" << std::endl;
-    }
-  }
-
-  // ----- check rewriting redundancy
-  if (keep && d_drewrite != nullptr)
-  {
-    Trace("sygus-synth-rr-debug") << "Check equal rewrite pair..." << std::endl;
-    if (d_drewrite->areEqual(bn, beq_n))
-    {
-      // must be unique according to the dynamic rewriter
-      Trace("sygus-synth-rr") << "...redundant (rewritable)" << std::endl;
-      keep = false;
-    }
-  }
-
-  if (options::sygusRewSynthFilterMatch())
-  {
-    // ----- check matchable
-    // check whether the pair is matchable with a previous one
-    d_curr_pair_rhs = beq_n;
-    Trace("sse-match") << "SSE check matches : " << bn << " [rhs = " << beq_n
-                       << "]..." << std::endl;
-    if (!d_match_trie.getMatches(bn, &d_ssenm))
-    {
-      keep = false;
-      Trace("sygus-synth-rr") << "...redundant (matchable)" << std::endl;
-      // regardless, would help to remember it
-      registerRelevantPair(n, eq_n);
-    }
-  }
-
-  if (keep)
-  {
-    return eq_n;
-  }
-  if (Trace.isOn("sygus-rr-filter"))
-  {
-    Printer* p = Printer::getPrinter(options::outputLanguage());
-    std::stringstream ss;
-    ss << "(redundant-rewrite ";
-    p->toStreamSygus(ss, n);
-    ss << " ";
-    p->toStreamSygus(ss, eq_n);
-    ss << ")";
-    Trace("sygus-rr-filter") << ss.str() << std::endl;
-  }
-  return Node::null();
-}
-
-void SygusSamplerExt::registerRelevantPair(Node n, Node eq_n)
-{
-  Node bn = n;
-  Node beq_n = eq_n;
-  if (d_use_sygus_type)
-  {
-    bn = d_tds->sygusToBuiltin(n);
-    beq_n = d_tds->sygusToBuiltin(eq_n);
-  }
-  // ----- check rewriting redundancy
-  if (d_drewrite != nullptr)
-  {
-    Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl;
-    Assert(!d_drewrite->areEqual(bn, beq_n));
-    d_drewrite->addRewrite(bn, beq_n);
-  }
-  if (options::sygusRewSynthFilterMatch())
-  {
-    // add to match information
-    for (unsigned r = 0; r < 2; r++)
-    {
-      Node t = r == 0 ? bn : beq_n;
-      Node to = r == 0 ? beq_n : bn;
-      // insert in match trie if first time
-      if (d_pairs.find(t) == d_pairs.end())
-      {
-        Trace("sse-match") << "SSE add term : " << t << std::endl;
-        d_match_trie.addTerm(t);
-      }
-      d_pairs[t].insert(to);
-    }
-  }
-}
-
-bool SygusSamplerExt::notify(Node s,
-                             Node n,
-                             std::vector<Node>& vars,
-                             std::vector<Node>& subs)
-{
-  Assert(!d_curr_pair_rhs.isNull());
-  std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
-      d_pairs.find(n);
-  if (Trace.isOn("sse-match"))
-  {
-    Trace("sse-match") << "  " << s << " matches " << n
-                       << " under:" << std::endl;
-    for (unsigned i = 0, size = vars.size(); i < size; i++)
-    {
-      Trace("sse-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
-      // TODO (#1923) ensure that we use an internal representation to
-      // ensure polymorphism is handled correctly
-      Assert(vars[i].getType().isComparableTo(subs[i].getType()));
-    }
-  }
-  Assert(it != d_pairs.end());
-  for (const Node& nr : it->second)
-  {
-    Node nrs =
-        nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
-    bool areEqual = (nrs == d_curr_pair_rhs);
-    if (!areEqual && d_drewrite != nullptr)
-    {
-      // if dynamic rewriter is available, consult it
-      areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
-    }
-    if (areEqual)
-    {
-      Trace("sse-match") << "*** Match, current pair: " << std::endl;
-      Trace("sse-match") << "  (" << s << ", " << d_curr_pair_rhs << ")"
-                         << std::endl;
-      Trace("sse-match") << "is an instance of previous pair:" << std::endl;
-      Trace("sse-match") << "  (" << n << ", " << nr << ")" << std::endl;
-      return false;
-    }
-  }
-  return true;
-}
-
-bool MatchTrie::getMatches(Node n, NotifyMatch* ntm)
-{
-  std::vector<Node> vars;
-  std::vector<Node> subs;
-  std::map<Node, Node> smap;
-
-  std::vector<std::vector<Node> > visit;
-  std::vector<MatchTrie*> visit_trie;
-  std::vector<int> visit_var_index;
-  std::vector<bool> visit_bound_var;
-
-  visit.push_back(std::vector<Node>{n});
-  visit_trie.push_back(this);
-  visit_var_index.push_back(-1);
-  visit_bound_var.push_back(false);
-  while (!visit.empty())
-  {
-    std::vector<Node> cvisit = visit.back();
-    MatchTrie* curr = visit_trie.back();
-    if (cvisit.empty())
-    {
-      Assert(n
-             == curr->d_data.substitute(
-                    vars.begin(), vars.end(), subs.begin(), subs.end()));
-      Trace("sse-match-debug") << "notify : " << curr->d_data << std::endl;
-      if (!ntm->notify(n, curr->d_data, vars, subs))
-      {
-        return false;
-      }
-      visit.pop_back();
-      visit_trie.pop_back();
-      visit_var_index.pop_back();
-      visit_bound_var.pop_back();
-    }
-    else
-    {
-      Node cn = cvisit.back();
-      Trace("sse-match-debug")
-          << "traverse : " << cn << " at depth " << visit.size() << std::endl;
-      unsigned index = visit.size() - 1;
-      int vindex = visit_var_index[index];
-      if (vindex == -1)
-      {
-        if (!cn.isVar())
-        {
-          Node op = cn.hasOperator() ? cn.getOperator() : cn;
-          unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0;
-          std::map<unsigned, MatchTrie>::iterator itu =
-              curr->d_children[op].find(nchild);
-          if (itu != curr->d_children[op].end())
-          {
-            // recurse on the operator or self
-            cvisit.pop_back();
-            if (cn.hasOperator())
-            {
-              for (const Node& cnc : cn)
-              {
-                cvisit.push_back(cnc);
-              }
-            }
-            Trace("sse-match-debug") << "recurse op : " << op << std::endl;
-            visit.push_back(cvisit);
-            visit_trie.push_back(&itu->second);
-            visit_var_index.push_back(-1);
-            visit_bound_var.push_back(false);
-          }
-        }
-        visit_var_index[index]++;
-      }
-      else
-      {
-        // clean up if we previously bound a variable
-        if (visit_bound_var[index])
-        {
-          Assert(!vars.empty());
-          smap.erase(vars.back());
-          vars.pop_back();
-          subs.pop_back();
-          visit_bound_var[index] = false;
-        }
-
-        if (vindex == static_cast<int>(curr->d_vars.size()))
-        {
-          Trace("sse-match-debug")
-              << "finished checking " << curr->d_vars.size()
-              << " variables at depth " << visit.size() << std::endl;
-          // finished
-          visit.pop_back();
-          visit_trie.pop_back();
-          visit_var_index.pop_back();
-          visit_bound_var.pop_back();
-        }
-        else
-        {
-          Trace("sse-match-debug") << "check variable #" << vindex
-                                   << " at depth " << visit.size() << std::endl;
-          Assert(vindex < static_cast<int>(curr->d_vars.size()));
-          // recurse on variable?
-          Node var = curr->d_vars[vindex];
-          bool recurse = true;
-          // check if it is already bound
-          std::map<Node, Node>::iterator its = smap.find(var);
-          if (its != smap.end())
-          {
-            if (its->second != cn)
-            {
-              recurse = false;
-            }
-          }
-          else
-          {
-            vars.push_back(var);
-            subs.push_back(cn);
-            smap[var] = cn;
-            visit_bound_var[index] = true;
-          }
-          if (recurse)
-          {
-            Trace("sse-match-debug") << "recurse var : " << var << std::endl;
-            cvisit.pop_back();
-            visit.push_back(cvisit);
-            visit_trie.push_back(&curr->d_children[var][0]);
-            visit_var_index.push_back(-1);
-            visit_bound_var.push_back(false);
-          }
-          visit_var_index[index]++;
-        }
-      }
-    }
-  }
-  return true;
-}
-
-void MatchTrie::addTerm(Node n)
-{
-  std::vector<Node> visit;
-  visit.push_back(n);
-  MatchTrie* curr = this;
-  while (!visit.empty())
-  {
-    Node cn = visit.back();
-    visit.pop_back();
-    if (cn.hasOperator())
-    {
-      curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]);
-      for (const Node& cnc : cn)
-      {
-        visit.push_back(cnc);
-      }
-    }
-    else
-    {
-      if (cn.isVar()
-          && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
-                 == curr->d_vars.end())
-      {
-        curr->d_vars.push_back(cn);
-      }
-      curr = &(curr->d_children[cn][0]);
-    }
-  }
-  curr->d_data = n;
-}
-
-void MatchTrie::clear()
-{
-  d_children.clear();
-  d_vars.clear();
-  d_data = Node::null();
-}
-
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */
index 290a8b17dc056f01f4e2849570c80e1f9cc476d7..0134b3a86c652cc5dde67f0eba0db98a0d1fea7d 100644 (file)
@@ -19,7 +19,6 @@
 
 #include <map>
 #include "theory/evaluator.h"
-#include "theory/quantifiers/dynamic_rewrite.h"
 #include "theory/quantifiers/lazy_trie.h"
 #include "theory/quantifiers/sygus/term_database_sygus.h"
 #include "theory/quantifiers/term_enumeration.h"
@@ -28,7 +27,6 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-
 /** SygusSampler
  *
  * This class can be used to test whether two expressions are equivalent
@@ -124,7 +122,7 @@ class SygusSampler : public LazyTrieEvaluator
    */
   int getDiffSamplePointIndex(Node a, Node b);
 
- protected:
+  //--------------------------queries about terms
   /** is contiguous
    *
    * This returns whether n's free variables (terms occurring in the range of
@@ -149,6 +147,7 @@ class SygusSampler : public LazyTrieEvaluator
    * occur in the range d_type_vars.
    */
   bool containsFreeVariables(Node a, Node b, bool strict = false);
+  //--------------------------end queries about terms
 
  protected:
   /** sygus term database of d_qe */
@@ -286,167 +285,6 @@ class SygusSampler : public LazyTrieEvaluator
   void registerSygusType(TypeNode tn);
 };
 
-/** A virtual class for notifications regarding matches. */
-class NotifyMatch
-{
- public:
-  virtual ~NotifyMatch() {}
-
-  /**
-   * A notification that s is equal to n * { vars -> subs }. This function
-   * should return false if we do not wish to be notified of further matches.
-   */
-  virtual bool notify(Node s,
-                      Node n,
-                      std::vector<Node>& vars,
-                      std::vector<Node>& subs) = 0;
-};
-
-/**
- * A trie (discrimination tree) storing a set of terms S, that can be used to
- * query, for a given term t, all terms from S that are matchable with t.
- */
-class MatchTrie
-{
- public:
-  /** Get matches
-   *
-   * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this
-   * trie that is matchable with n where s = n * { vars -> subs } for some
-   * vars, subs. This function returns false if one of these calls to notify
-   * returns false.
-   */
-  bool getMatches(Node n, NotifyMatch* ntm);
-  /** Adds node n to this trie */
-  void addTerm(Node n);
-  /** Clear this trie */
-  void clear();
-
- private:
-  /**
-   * The children of this node in the trie. Terms t are indexed by a
-   * depth-first (right to left) traversal on its subterms, where the
-   * top-symbol of t is indexed by:
-   * - (operator, #children) if t has an operator, or
-   * - (t, 0) if t does not have an operator.
-   */
-  std::map<Node, std::map<unsigned, MatchTrie> > d_children;
-  /** The set of variables in the domain of d_children */
-  std::vector<Node> d_vars;
-  /** The data of this node in the trie */
-  Node d_data;
-};
-
-/** Version of the above class with some additional features */
-class SygusSamplerExt : public SygusSampler
-{
- public:
-  SygusSamplerExt();
-  /** initialize */
-  void initializeSygus(TermDbSygus* tds,
-                       Node f,
-                       unsigned nsamples,
-                       bool useSygusType) override;
-  /** set dynamic rewriter
-   *
-   * This tells this class to use the dynamic rewriter object dr. This utility
-   * is used to query whether pairs of terms are already entailed to be
-   * equal based on previous rewrite rules.
-   */
-  void setDynamicRewriter(DynamicRewriter* dr);
-
-  /** register term n with this sampler database
-   *
-   *  For each call to registerTerm( t, ... ) that returns s, we say that
-   * (t,s) and (s,t) are "relevant pairs".
-   *
-   * This returns either null, or a term ret with the same guarantees as
-   * SygusSampler::registerTerm with the additional guarantee
-   * that for all previous relevant pairs ( n', nret' ),
-   * we have that n = ret is not an instance of n' = ret'
-   * modulo symmetry of equality, nor is n = ret derivable from the set of
-   * all previous relevant pairs. The latter is determined by the d_drewrite
-   * utility. For example:
-   * [1]  ( t+0, t ) and ( x+0, x )
-   * will not both be relevant pairs of this function since t+0=t is
-   * an instance of x+0=x.
-   * [2]  ( s, t ) and ( s+0, t+0 )
-   * will not both be relevant pairs of this function since s+0=t+0 is
-   * derivable from s=t.
-   * These two criteria may be combined, for example:
-   * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s )
-   * are relevant pairs, since t+0 is an instance of x+0 where
-   * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair
-   * above (t+s = s).
-   *
-   * If this function returns null, then n is equivalent to a previously
-   * registered term ret, and the equality ( n, ret ) is either an instance
-   * of a previous relevant pair ( n', ret' ), or n = ret is derivable
-   * from the set of all previous relevant pairs based on the
-   * d_drewrite utility, or is an instance of a previous pair
-   */
-  Node registerTerm(Node n, bool forceKeep = false) override;
-  /** register relevant pair
-   *
-   * This should be called after registerTerm( n ) returns eq_n.
-   * This registers ( n, eq_n ) as a relevant pair with this class.
-   */
-  void registerRelevantPair(Node n, Node eq_n);
-
- private:
-  /** pointer to the dynamic rewriter class */
-  DynamicRewriter* d_drewrite;
-
-  //----------------------------match filtering
-  /**
-   * Stores all relevant pairs returned by this sampler (see registerTerm). In
-   * detail, if (t,s) is a relevant pair, then t in d_pairs[s].
-   */
-  std::map<Node, std::unordered_set<Node, NodeHashFunction> > d_pairs;
-  /** Match trie storing all terms in the domain of d_pairs. */
-  MatchTrie d_match_trie;
-  /** Notify class */
-  class SygusSamplerExtNotifyMatch : public NotifyMatch
-  {
-    SygusSamplerExt& d_sse;
-
-   public:
-    SygusSamplerExtNotifyMatch(SygusSamplerExt& sse) : d_sse(sse) {}
-    /** notify match */
-    bool notify(Node n,
-                Node s,
-                std::vector<Node>& vars,
-                std::vector<Node>& subs) override
-    {
-      return d_sse.notify(n, s, vars, subs);
-    }
-  };
-  /** Notify object used for reporting matches from d_match_trie */
-  SygusSamplerExtNotifyMatch d_ssenm;
-  /**
-   * Stores the current right hand side of a pair we are considering.
-   *
-   * In more detail, in registerTerm, we are interested in whether a pair (s,t)
-   * is a relevant pair. We do this by:
-   * (1) Setting the node d_curr_pair_rhs to t,
-   * (2) Using d_match_trie, compute all terms s1...sn that match s.
-   * For each si, where s = si * sigma for some substitution sigma, we check
-   * whether t = ti * sigma for some previously relevant pair (si,ti), in
-   * which case (s,t) is an instance of (si,ti).
-   */
-  Node d_curr_pair_rhs;
-  /**
-   * Called by the above class during d_match_trie.getMatches( s ), when we
-   * find that si = s * sigma, where si is a term that is stored in
-   * d_match_trie.
-   *
-   * This function returns false if ( s, d_curr_pair_rhs ) is an instance of
-   * previously relevant pair.
-   */
-  bool notify(Node s, Node n, std::vector<Node>& vars, std::vector<Node>& subs);
-  //----------------------------end match filtering
-};
-
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */