Make candidate rewrite match filtering handle polymorphic operators properly (#2236)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 1 Aug 2018 06:31:14 +0000 (01:31 -0500)
committerAndres Noetzli <andres.noetzli@gmail.com>
Wed, 1 Aug 2018 06:31:14 +0000 (23:31 -0700)
Currently, the discrimination tree index used for candidate rewrite rule filtering based on matching does not properly distinguish polymorphic operators, which leads to type errors. This makes the index handle them correctly.

Fixes #1923.

src/theory/quantifiers/candidate_rewrite_filter.cpp
src/theory/quantifiers/candidate_rewrite_filter.h
src/theory/quantifiers/dynamic_rewrite.cpp
src/theory/quantifiers/dynamic_rewrite.h

index 118c073a084074b759d492076248ca58bce494f8..bf2248e259515a2f3e2de9f1abf5a5a638692db1 100644 (file)
@@ -219,15 +219,12 @@ void CandidateRewriteFilter::initialize(SygusSampler* ss,
   // initialize members of this class
   d_match_trie.clear();
   d_pairs.clear();
-  if (options::sygusRewSynthFilterCong())
-  {
-    // initialize the dynamic rewriter
-    std::stringstream ss;
-    ss << "_dyn_rewriter_" << drewrite_counter;
-    drewrite_counter++;
-    d_drewrite = std::unique_ptr<DynamicRewriter>(
-        new DynamicRewriter(ss.str(), &d_fake_context));
-  }
+  // (re)initialize the dynamic rewriter
+  std::stringstream ssn;
+  ssn << "_dyn_rewriter_" << drewrite_counter;
+  drewrite_counter++;
+  d_drewrite = std::unique_ptr<DynamicRewriter>(
+      new DynamicRewriter(ssn.str(), &d_fake_context));
 }
 
 bool CandidateRewriteFilter::filterPair(Node n, Node eq_n)
@@ -285,7 +282,7 @@ bool CandidateRewriteFilter::filterPair(Node n, Node eq_n)
   }
 
   // ----- check rewriting redundancy
-  if (keep && d_drewrite != nullptr)
+  if (keep && options::sygusRewSynthFilterCong())
   {
     Trace("cr-filter-debug") << "Check equal rewrite pair..." << std::endl;
     if (d_drewrite->areEqual(bn, beq_n))
@@ -296,14 +293,15 @@ bool CandidateRewriteFilter::filterPair(Node n, Node eq_n)
     }
   }
 
-  if (options::sygusRewSynthFilterMatch())
+  if (keep && options::sygusRewSynthFilterMatch())
   {
     // ----- check matchable
     // check whether the pair is matchable with a previous one
     d_curr_pair_rhs = beq_n;
     Trace("crf-match") << "CRF check matches : " << bn << " [rhs = " << beq_n
                        << "]..." << std::endl;
-    if (!d_match_trie.getMatches(bn, &d_ssenm))
+    Node bni = d_drewrite->toInternal(bn);
+    if (!d_match_trie.getMatches(bni, &d_ssenm))
     {
       keep = false;
       Trace("cr-filter") << "...redundant (matchable)" << std::endl;
@@ -340,7 +338,7 @@ void CandidateRewriteFilter::registerRelevantPair(Node n, Node eq_n)
     beq_n = d_tds->sygusToBuiltin(eq_n);
   }
   // ----- check rewriting redundancy
-  if (d_drewrite != nullptr)
+  if (options::sygusRewSynthFilterCong())
   {
     Trace("cr-filter-debug") << "Add rewrite pair..." << std::endl;
     Assert(!d_drewrite->areEqual(bn, beq_n));
@@ -357,7 +355,8 @@ void CandidateRewriteFilter::registerRelevantPair(Node n, Node eq_n)
       if (d_pairs.find(t) == d_pairs.end())
       {
         Trace("crf-match") << "CRF add term : " << t << std::endl;
-        d_match_trie.addTerm(t);
+        Node ti = d_drewrite->toInternal(t);
+        d_match_trie.addTerm(ti);
       }
       d_pairs[t].insert(to);
     }
@@ -369,7 +368,13 @@ bool CandidateRewriteFilter::notify(Node s,
                                     std::vector<Node>& vars,
                                     std::vector<Node>& subs)
 {
+  Trace("crf-match-debug") << "Got : " << s << " matches " << n << std::endl;
   Assert(!d_curr_pair_rhs.isNull());
+  // convert back to original forms
+  s = d_drewrite->toExternal(s);
+  Assert(!s.isNull());
+  n = d_drewrite->toExternal(n);
+  Assert(!n.isNull());
   std::map<Node, std::unordered_set<Node, NodeHashFunction> >::iterator it =
       d_pairs.find(n);
   if (Trace.isOn("crf-match"))
@@ -379,18 +384,29 @@ bool CandidateRewriteFilter::notify(Node s,
     for (unsigned i = 0, size = vars.size(); i < size; i++)
     {
       Trace("crf-match") << "    " << vars[i] << " -> " << subs[i] << std::endl;
-      // TODO (#1923) ensure that we use an internal representation to
-      // ensure polymorphism is handled correctly
-      Assert(vars[i].getType().isComparableTo(subs[i].getType()));
     }
   }
+#ifdef CVC4_ASSERTIONS
+  for (unsigned i = 0, size = vars.size(); i < size; i++)
+  {
+    // By using internal representation of terms, we ensure polymorphism is
+    // handled correctly.
+    Assert(vars[i].getType().isComparableTo(subs[i].getType()));
+  }
+#endif
+  // must convert the inferred substitution to original form
+  std::vector<Node> esubs;
+  for (const Node& s : subs)
+  {
+    esubs.push_back(d_drewrite->toExternal(s));
+  }
   Assert(it != d_pairs.end());
   for (const Node& nr : it->second)
   {
     Node nrs =
-        nr.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+        nr.substitute(vars.begin(), vars.end(), esubs.begin(), esubs.end());
     bool areEqual = (nrs == d_curr_pair_rhs);
-    if (!areEqual && d_drewrite != nullptr)
+    if (!areEqual && options::sygusRewSynthFilterCong())
     {
       // if dynamic rewriter is available, consult it
       areEqual = d_drewrite->areEqual(nrs, d_curr_pair_rhs);
index 9a09680cc490daf75b77922a82d9672fe5f3a027..ca071faa459e857e11848bfefd13fa322bcb6fe4 100644 (file)
@@ -165,7 +165,13 @@ class CandidateRewriteFilter
    * detail, if (t,s) is a relevant pair, then t in d_pairs[s].
    */
   std::map<Node, std::unordered_set<Node, NodeHashFunction> > d_pairs;
-  /** Match trie storing all terms in the domain of d_pairs. */
+  /** Match trie storing all terms in the domain of d_pairs.
+   *
+   * Notice that we store d_drewrite->toInternal(t) instead of t, for each
+   * term t, so that polymorphism is handled properly. In particular, this
+   * prevents matches between terms select( x, y ) and select( z, y ) where
+   * the type of x and z are different.
+   */
   MatchTrie d_match_trie;
   /** Notify class */
   class CandidateRewriteFilterNotifyMatch : public NotifyMatch
index ef1cb3a9d61aa1f9dd52435a2ed33859b10ffe53..d8c5ac7b5c135955031cb89810e9f5e0d8770c2f 100644 (file)
@@ -124,9 +124,20 @@ Node DynamicRewriter::toInternal(Node a)
     }
   }
   d_term_to_internal[a] = ret;
+  d_internal_to_term[ret] = a;
   return ret;
 }
 
+Node DynamicRewriter::toExternal(Node ai)
+{
+  std::map<Node, Node>::iterator it = d_internal_to_term.find(ai);
+  if (it != d_internal_to_term.end())
+  {
+    return it->second;
+  }
+  return Node::null();
+}
+
 Node DynamicRewriter::OpInternalSymTrie::getSymbol(Node n)
 {
   std::vector<TypeNode> ctypes;
index 75f668b1130acf4c93dbde1e486cf310e7f0f1a1..50bae0268931f1e0d95beaa57924af86d4deee49 100644 (file)
@@ -62,6 +62,17 @@ class DynamicRewriter
    * Check whether this class knows that the equality a = b holds.
    */
   bool areEqual(Node a, Node b);
+  /**
+   * Convert node a to its internal representation, which replaces all
+   * interpreted operators in a by a unique uninterpreted symbol.
+   */
+  Node toInternal(Node a);
+  /**
+   * Convert internal node ai to its original representation. It is the case
+   * that toExternal(toInternal(a))=a. If ai is not a term returned by
+   * toInternal, this method returns null.
+   */
+  Node toExternal(Node ai);
 
  private:
   /** index over argument types to function skolems
@@ -96,13 +107,10 @@ class DynamicRewriter
   };
   /** the internal operator symbol trie for this class */
   std::map<Node, OpInternalSymTrie> d_ois_trie;
-  /**
-   * Convert node a to its internal representation, which replaces all
-   * interpreted operators in a by a unique uninterpreted symbol.
-   */
-  Node toInternal(Node a);
   /** cache of the above function */
   std::map<Node, Node> d_term_to_internal;
+  /** inverse of the above map */
+  std::map<Node, Node> d_internal_to_term;
   /** stores congruence closure over terms given to this class. */
   eq::EqualityEngine d_equalityEngine;
   /** list of all equalities asserted to equality engine */