Filter candidate rewrites based on matching (#1682)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 27 Mar 2018 16:53:49 +0000 (11:53 -0500)
committerGitHub <noreply@github.com>
Tue, 27 Mar 2018 16:53:49 +0000 (11:53 -0500)
src/theory/quantifiers/dynamic_rewrite.cpp
src/theory/quantifiers/dynamic_rewrite.h
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index 3462a4d102aafb4735f607b9c0fd269a17805fa6..cb7379910318ffd2b2245333c9aea0d2c44ffaf7 100644 (file)
@@ -66,6 +66,20 @@ bool DynamicRewriter::addRewrite(Node a, Node b)
   return true;
 }
 
+bool DynamicRewriter::areEqual(Node a, Node b)
+{
+  if (a == b)
+  {
+    return true;
+  }
+  // add to the equality engine
+  Node ai = toInternal(a);
+  Node bi = toInternal(b);
+  d_equalityEngine.addTerm(ai);
+  d_equalityEngine.addTerm(bi);
+  return d_equalityEngine.areEqual(ai, bi);
+}
+
 Node DynamicRewriter::toInternal(Node a)
 {
   std::map<Node, Node>::iterator it = d_term_to_internal.find(a);
index 2b546415169834e586324c318cb41464ce5fccaf..388173829baa97d0c6264b504d048870b71ddcf8 100644 (file)
@@ -63,6 +63,10 @@ class DynamicRewriter
    * a = b based on the previous equalities it has seen.
    */
   bool addRewrite(Node a, Node b);
+  /**
+   * Check whether this class knows that the equality a = b holds.
+   */
+  bool areEqual(Node a, Node b);
 
  private:
   /** pointer to the quantifiers engine */
index afbdc42e18cfa12973fa637b4878921077b6d17c..99494657f2e1210a0256978887acfc4aad2e2784 100644 (file)
@@ -678,6 +678,8 @@ void SygusSampler::registerSygusType(TypeNode tn)
   }
 }
 
+SygusSamplerExt::SygusSamplerExt() : d_ssenm(*this) {}
+
 void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
                                          Node f,
                                          unsigned nsamples,
@@ -691,6 +693,8 @@ void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
   ss << f;
   d_drewrite =
       std::unique_ptr<DynamicRewriter>(new DynamicRewriter(ss.str(), qe));
+  d_pairs.clear();
+  d_match_trie.clear();
 }
 
 Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
@@ -700,6 +704,7 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
                           << std::endl;
   if (eq_n == n)
   {
+    // this is a unique term
     return n;
   }
   Node bn = n;
@@ -709,63 +714,268 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
     bn = d_tds->sygusToBuiltin(n);
     beq_n = d_tds->sygusToBuiltin(eq_n);
   }
-  // one of eq_n or n must be ordered
-  bool eqor = isOrdered(beq_n);
-  bool nor = isOrdered(bn);
-  Trace("sygus-synth-rr-debug")
-      << "Ordered? : " << nor << " " << eqor << std::endl;
-  bool isUnique = false;
-  if (eqor || nor)
+  // whether we will keep this pair
+  bool keep = true;
+
+  // ----- check matchable
+  // check whether the pair is matchable with a previous one
+  d_curr_pair_rhs = beq_n;
+  Trace("sse-match") << "SSE check matches : " << n << " [rhs = " << eq_n
+                     << "]..." << std::endl;
+  if (!d_match_trie.getMatches(bn, &d_ssenm))
   {
-    isUnique = true;
-    // if only one is ordered, then the ordered one must contain the
-    // free variables of the other
-    if (!eqor)
-    {
-      isUnique = containsFreeVariables(bn, beq_n);
-    }
-    else if (!nor)
-    {
-      isUnique = containsFreeVariables(beq_n, bn);
-    }
+    keep = false;
+    Trace("sygus-synth-rr-debug") << "...redundant (matchable)" << std::endl;
   }
-  Trace("sygus-synth-rr-debug") << "AlphaEq unique: " << isUnique << std::endl;
-  bool rewRedundant = false;
+
+  // ----- check rewriting redundancy
   if (d_drewrite != nullptr)
   {
-    Trace("sygus-synth-rr-debug") << "Add rewrite..." << std::endl;
+    Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl;
     if (!d_drewrite->addRewrite(bn, beq_n))
     {
-      rewRedundant = isUnique;
       // must be unique according to the dynamic rewriter
-      isUnique = false;
+      keep = false;
+      Trace("sygus-synth-rr-debug") << "...redundant (rewritable)" << std::endl;
     }
   }
-  Trace("sygus-synth-rr-debug") << "Rewrite unique: " << isUnique << std::endl;
 
-  if (isUnique)
+  if (keep)
   {
-    // if the previous value stored was unordered, but this is
-    // ordered, we prefer this one. Thus, we force its addition to the
-    // sampler database.
-    if (!eqor)
+    // add to match information
+    for (unsigned r = 0; r < 2; r++)
     {
-      SygusSampler::registerTerm(n, true);
+      Node t = r == 0 ? bn : beq_n;
+      Node to = r == 0 ? beq_n : bn;
+      // insert in match trie if first time
+      if (d_pairs.find(t) == d_pairs.end())
+      {
+        Trace("sse-match") << "SSE add term : " << t << std::endl;
+        d_match_trie.addTerm(t);
+      }
+      d_pairs[t].insert(to);
     }
     return eq_n;
   }
   else if (Trace.isOn("sygus-synth-rr"))
   {
-    Trace("sygus-synth-rr") << "Redundant rewrite : " << eq_n << " " << n;
-    if (rewRedundant)
-    {
-      Trace("sygus-synth-rr") << " (by rewriting)";
-    }
+    Trace("sygus-synth-rr") << "Redundant pair : " << eq_n << " " << n;
     Trace("sygus-synth-rr") << std::endl;
   }
   return Node::null();
 }
 
+bool SygusSamplerExt::notify(Node s,
+                             Node n,
+                             std::vector<Node>& vars,
+                             std::vector<Node>& subs)
+{
+  Assert(!d_curr_pair_rhs.isNull());
+  std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
+      d_pairs.find(n);
+  if (Trace.isOn("sse-match"))
+  {
+    Trace("sse-match") << "  " << s << " matches " << n
+                       << " under:" << std::endl;
+    for (unsigned i = 0, size = vars.size(); i < size; i++)
+    {
+      Trace("sse-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
+    }
+  }
+  Assert(it != d_pairs.end());
+  for (const Node& nr : it->second)
+  {
+    Node nrs =
+        nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+    bool areEqual = (nrs == d_curr_pair_rhs);
+    if (!areEqual && d_drewrite != nullptr)
+    {
+      // if dynamic rewriter is available, consult it
+      areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
+    }
+    if (areEqual)
+    {
+      Trace("sse-match") << "*** Match, current pair: " << std::endl;
+      Trace("sse-match") << "  (" << s << ", " << d_curr_pair_rhs << ")"
+                         << std::endl;
+      Trace("sse-match") << "is an instance of previous pair:" << std::endl;
+      Trace("sse-match") << "  (" << n << ", " << nr << ")" << std::endl;
+      return false;
+    }
+  }
+  return true;
+}
+
+bool MatchTrie::getMatches(Node n, NotifyMatch* ntm)
+{
+  std::vector<Node> vars;
+  std::vector<Node> subs;
+  std::map<Node, Node> smap;
+
+  std::vector<std::vector<Node> > visit;
+  std::vector<MatchTrie*> visit_trie;
+  std::vector<int> visit_var_index;
+  std::vector<bool> visit_bound_var;
+
+  visit.push_back(std::vector<Node>{n});
+  visit_trie.push_back(this);
+  visit_var_index.push_back(-1);
+  visit_bound_var.push_back(false);
+  while (!visit.empty())
+  {
+    std::vector<Node> cvisit = visit.back();
+    MatchTrie* curr = visit_trie.back();
+    if (cvisit.empty())
+    {
+      Assert(n
+             == curr->d_data.substitute(
+                    vars.begin(), vars.end(), subs.begin(), subs.end()));
+      Trace("sse-match-debug") << "notify : " << curr->d_data << std::endl;
+      if (!ntm->notify(n, curr->d_data, vars, subs))
+      {
+        return false;
+      }
+      visit.pop_back();
+      visit_trie.pop_back();
+      visit_var_index.pop_back();
+      visit_bound_var.pop_back();
+    }
+    else
+    {
+      Node cn = cvisit.back();
+      Trace("sse-match-debug")
+          << "traverse : " << cn << " at depth " << visit.size() << std::endl;
+      unsigned index = visit.size() - 1;
+      int vindex = visit_var_index[index];
+      if (vindex == -1)
+      {
+        if (!cn.isVar())
+        {
+          Node op = cn.hasOperator() ? cn.getOperator() : cn;
+          unsigned nchild = cn.hasOperator() ? cn.getNumChildren() : 0;
+          std::map<unsigned, MatchTrie>::iterator itu =
+              curr->d_children[op].find(nchild);
+          if (itu != curr->d_children[op].end())
+          {
+            // recurse on the operator or self
+            cvisit.pop_back();
+            if (cn.hasOperator())
+            {
+              for (const Node& cnc : cn)
+              {
+                cvisit.push_back(cnc);
+              }
+            }
+            Trace("sse-match-debug") << "recurse op : " << op << std::endl;
+            visit.push_back(cvisit);
+            visit_trie.push_back(&itu->second);
+            visit_var_index.push_back(-1);
+            visit_bound_var.push_back(false);
+          }
+        }
+        visit_var_index[index]++;
+      }
+      else
+      {
+        // clean up if we previously bound a variable
+        if (visit_bound_var[index])
+        {
+          Assert(!vars.empty());
+          smap.erase(vars.back());
+          vars.pop_back();
+          subs.pop_back();
+        }
+
+        if (vindex == static_cast<int>(curr->d_vars.size()))
+        {
+          Trace("sse-match-debug")
+              << "finished checking " << curr->d_vars.size()
+              << " variables at depth " << visit.size() << std::endl;
+          // finished
+          visit.pop_back();
+          visit_trie.pop_back();
+          visit_var_index.pop_back();
+          visit_bound_var.pop_back();
+        }
+        else
+        {
+          Trace("sse-match-debug") << "check variable #" << vindex
+                                   << " at depth " << visit.size() << std::endl;
+          Assert(vindex < static_cast<int>(curr->d_vars.size()));
+          // recurse on variable?
+          Node var = curr->d_vars[vindex];
+          bool recurse = true;
+          // check if it is already bound
+          std::map<Node, Node>::iterator its = smap.find(var);
+          if (its != smap.end())
+          {
+            if (its->second != cn)
+            {
+              recurse = false;
+            }
+          }
+          else
+          {
+            vars.push_back(var);
+            subs.push_back(cn);
+            smap[var] = cn;
+            visit_bound_var[index] = true;
+          }
+          if (recurse)
+          {
+            Trace("sse-match-debug") << "recurse var : " << var << std::endl;
+            cvisit.pop_back();
+            visit.push_back(cvisit);
+            visit_trie.push_back(&curr->d_children[var][0]);
+            visit_var_index.push_back(-1);
+            visit_bound_var.push_back(false);
+          }
+          visit_var_index[index]++;
+        }
+      }
+    }
+  }
+  return true;
+}
+
+void MatchTrie::addTerm(Node n)
+{
+  std::vector<Node> visit;
+  visit.push_back(n);
+  MatchTrie* curr = this;
+  while (!visit.empty())
+  {
+    Node cn = visit.back();
+    visit.pop_back();
+    if (cn.hasOperator())
+    {
+      curr = &(curr->d_children[cn.getOperator()][cn.getNumChildren()]);
+      for (const Node& cnc : cn)
+      {
+        visit.push_back(cnc);
+      }
+    }
+    else
+    {
+      if (cn.isVar()
+          && std::find(curr->d_vars.begin(), curr->d_vars.end(), cn)
+                 == curr->d_vars.end())
+      {
+        curr->d_vars.push_back(cn);
+      }
+      curr = &(curr->d_children[cn][0]);
+    }
+  }
+  curr->d_data = n;
+}
+
+void MatchTrie::clear()
+{
+  d_children.clear();
+  d_vars.clear();
+  d_data = Node::null();
+}
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */
index 4bc10075dea0e8475f09b887f2fd1dbe67cf181f..fa0d670d27bb2fbff65b5720705bef8e1568ffe4 100644 (file)
@@ -340,42 +340,149 @@ class SygusSampler : public LazyTrieEvaluator
   void registerSygusType(TypeNode tn);
 };
 
+/** A virtual class for notifications regarding matches. */
+class NotifyMatch
+{
+ public:
+  /**
+   * A notification that s is equal to n * { vars -> subs }. This function
+   * should return false if we do not wish to be notified of further matches.
+   */
+  virtual bool notify(Node s,
+                      Node n,
+                      std::vector<Node>& vars,
+                      std::vector<Node>& subs) = 0;
+};
+
+/**
+ * A trie (discrimination tree) storing a set of terms S, that can be used to
+ * query, for a given term t, all terms from S that are matchable with t.
+ */
+class MatchTrie
+{
+ public:
+  /** Get matches
+   *
+   * This calls ntm->notify( n, s, vars, subs ) for each term s stored in this
+   * trie that is matchable with n where s = n * { vars -> subs } for some
+   * vars, subs. This function returns false if one of these calls to notify
+   * returns false.
+   */
+  bool getMatches(Node n, NotifyMatch* ntm);
+  /** Adds node n to this trie */
+  void addTerm(Node n);
+  /** Clear this trie */
+  void clear();
+
+ private:
+  /**
+   * The children of this node in the trie. Terms t are indexed by a
+   * depth-first (right to left) traversal on its subterms, where the
+   * top-symbol of t is indexed by:
+   * - (operator, #children) if t has an operator, or
+   * - (t, 0) if t does not have an operator.
+   */
+  std::map<Node, std::map<unsigned, MatchTrie> > d_children;
+  /** The set of variables in the domain of d_children */
+  std::vector<Node> d_vars;
+  /** The data of this node in the trie */
+  Node d_data;
+};
+
 /** Version of the above class with some additional features */
 class SygusSamplerExt : public SygusSampler
 {
  public:
+  SygusSamplerExt();
   /** initialize extended */
   void initializeSygusExt(QuantifiersEngine* qe,
                           Node f,
                           unsigned nsamples,
                           bool useSygusType);
   /** register term n with this sampler database
+   *
+   *  For each call to registerTerm( t, ... ) that returns s, we say that
+   * (t,s) and (s,t) are "relevant pairs".
    *
    * This returns either null, or a term ret with the same guarantees as
    * SygusSampler::registerTerm with the additional guarantee
-   * that for all ret' returned by a previous call to registerTerm( n' ),
-   * we have that n = ret is not alpha-equivalent to n' = ret'
+   * that for all previous relevant pairs ( n', nret' ),
+   * we have that n = ret is not an instance of n' = ret'
    * modulo symmetry of equality, nor is n = ret derivable from the set of
-   * all previous input/output pairs based on the d_drewrite utility.
-   * For example,
-   *   (t+0), t and (s+0), s
-   * will not both be input/output pairs of this function since t+0=t is
-   * alpha-equivalent to s+0=s.
-   *   s, t and s+0, t+0
-   * will not both be input/output pairs of this function since s+0=t+0 is
+   * all previous relevant pairs. The latter is determined by the d_drewrite
+   * utility. For example:
+   * [1]  ( t+0, t ) and ( x+0, x )
+   * will not both be relevant pairs of this function since t+0=t is
+   * an instance of x+0=x.
+   * [2]  ( s, t ) and ( s+0, t+0 )
+   * will not both be relevant pairs of this function since s+0=t+0 is
    * derivable from s=t.
+   * These two criteria may be combined, for example:
+   * [3] ( t+0, s ) is not a relevant pair if both ( x+0, x+s ) and ( t+s, s )
+   * are relevant pairs, since t+0 is an instance of x+0 where
+   * { x |-> t }, and x+s { x |-> t } = s is derivable, via the third pair
+   * above (t+s = s).
    *
    * If this function returns null, then n is equivalent to a previously
-   * registered term ret, and the equality n = ret is either alpha-equivalent
-   * to a previous input/output pair n' = ret', or n = ret is derivable
-   * from the set of all previous input/output pairs based on the
-   * d_drewrite utility.
+   * registered term ret, and the equality ( n, ret ) is either an instance
+   * of a previous relevant pair ( n', ret' ), or n = ret is derivable
+   * from the set of all previous relevant pairs based on the
+   * d_drewrite utility, or is an instance of a previous pair
    */
   Node registerTerm(Node n, bool forceKeep = false) override;
 
  private:
   /** dynamic rewriter class */
   std::unique_ptr<DynamicRewriter> d_drewrite;
+
+  //----------------------------match filtering
+  /**
+   * Stores all relevant pairs returned by this sampler (see registerTerm). In
+   * detail, if (t,s) is a relevant pair, then t in d_pairs[s].
+   */
+  std::map<Node, std::unordered_set<Node, NodeHashFunction> > d_pairs;
+  /** Match trie storing all terms in the domain of d_pairs. */
+  MatchTrie d_match_trie;
+  /** Notify class */
+  class SygusSamplerExtNotifyMatch : public NotifyMatch
+  {
+    SygusSamplerExt& d_sse;
+
+   public:
+    SygusSamplerExtNotifyMatch(SygusSamplerExt& sse) : d_sse(sse) {}
+    /** notify match */
+    bool notify(Node n,
+                Node s,
+                std::vector<Node>& vars,
+                std::vector<Node>& subs) override
+    {
+      return d_sse.notify(n, s, vars, subs);
+    }
+  };
+  /** Notify object used for reporting matches from d_match_trie */
+  SygusSamplerExtNotifyMatch d_ssenm;
+  /**
+   * Stores the current right hand side of a pair we are considering.
+   *
+   * In more detail, in registerTerm, we are interested in whether a pair (s,t)
+   * is a relevant pair. We do this by:
+   * (1) Setting the node d_curr_pair_rhs to t,
+   * (2) Using d_match_trie, compute all terms s1...sn that match s.
+   * For each si, where s = si * sigma for some substitution sigma, we check
+   * whether t = ti * sigma for some previously relevant pair (si,ti), in
+   * which case (s,t) is an instance of (si,ti).
+   */
+  Node d_curr_pair_rhs;
+  /**
+   * Called by the above class during d_match_trie.getMatches( s ), when we
+   * find that si = s * sigma, where si is a term that is stored in
+   * d_match_trie.
+   *
+   * This function returns false if ( s, d_curr_pair_rhs ) is an instance of
+   * previously relevant pair.
+   */
+  bool notify(Node s, Node n, std::vector<Node>& vars, std::vector<Node>& subs);
+  //----------------------------end match filtering
 };
 
 } /* CVC4::theory::quantifiers namespace */