Option to use extended rewriter as a preprocessing pass (#1600)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 13 Feb 2018 00:16:59 +0000 (18:16 -0600)
committerGitHub <noreply@github.com>
Tue, 13 Feb 2018 00:16:59 +0000 (18:16 -0600)
src/options/smt_options
src/smt/smt_engine.cpp
src/theory/quantifiers/extended_rewrite.cpp
src/theory/quantifiers/extended_rewrite.h
src/theory/quantifiers/term_database_sygus.cpp

index b19420060f76d0361aa7ab00f803c058a021f6f7..72189ea134a87f318ffed6c2ad43edca95c95115 100644 (file)
@@ -70,6 +70,11 @@ option doITESimp --ite-simp bool :read-write
 
 option doITESimpOnRepeat --on-repeat-ite-simp bool :read-write :default false
  do the ite simplification pass again if repeating simplification
+option extRewPrep --ext-rew-prep bool :read-write :default false
+ use extended rewriter as a preprocessing pass
+option extRewPrepAgg --ext-rew-prep-agg bool :read-write :default false
+ use aggressive extended rewriter as a preprocessing pass
 
 option simplifyWithCareEnabled --simp-with-care bool :default false :read-write
  enables simplifyWithCare in ite simplificiation
index fdd72ba2e21d2829f2532bc0818d0acb0218e275..7e2f6c38c40805f09ff245a364a4636e1687aa18 100644 (file)
@@ -4293,6 +4293,16 @@ void SmtEnginePrivate::processAssertions() {
 
   bool noConflict = true;
 
+  if (options::extRewPrep())
+  {
+    theory::quantifiers::ExtendedRewriter extr(options::extRewPrepAgg());
+    for (unsigned i = 0; i < d_assertions.size(); ++i)
+    {
+      Node a = d_assertions[i];
+      d_assertions.replace(i, extr.extendedRewrite(a));
+    }
+  }
+
   // Unconstrained simplification
   if(options::unconstrainedSimp()) {
     Trace("smt-proc") << "SmtEnginePrivate::processAssertions() : pre-unconstrained-simp" << endl;
index ba0860d38f452b2b89c1a5d0a51f3e375643d04e..dd4fc86ba908ff49838337481d521062a7f0dbd7 100644 (file)
@@ -26,7 +26,7 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-ExtendedRewriter::ExtendedRewriter()
+ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
@@ -76,8 +76,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Node n)
           Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to "
                                  << nc << " by simple ITE pulling."
                                  << std::endl;
-          // recurse
-          return extendedRewrite(nc);
+          return nc;
         }
       }
     }
@@ -87,205 +86,201 @@ Node ExtendedRewriter::extendedRewritePullIte(Node n)
 
 Node ExtendedRewriter::extendedRewrite(Node n)
 {
+  n = Rewriter::rewrite(n);
   std::unordered_map<Node, Node, NodeHashFunction>::iterator it =
       d_ext_rewrite_cache.find(n);
-  if (it == d_ext_rewrite_cache.end())
+  if (it != d_ext_rewrite_cache.end())
   {
-    Node ret = n;
-    if (n.getNumChildren() > 0)
+    return it->second;
+  }
+  Node ret = n;
+  if (n.getNumChildren() > 0)
+  {
+    std::vector<Node> children;
+    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
     {
-      std::vector<Node> children;
-      if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
-      {
-        children.push_back(n.getOperator());
-      }
-      bool childChanged = false;
-      for (unsigned i = 0; i < n.getNumChildren(); i++)
-      {
-        Node nc = extendedRewrite(n[i]);
-        childChanged = nc != n[i] || childChanged;
-        children.push_back(nc);
-      }
-      // Some commutative operators have rewriters that are agnostic to order,
-      // thus, we sort here.
-      if (TermUtil::isComm(n.getKind()))
-      {
-        childChanged = true;
-        std::sort(children.begin(), children.end());
-      }
-      if (childChanged)
-      {
-        ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
-      }
+      children.push_back(n.getOperator());
     }
-    ret = Rewriter::rewrite(ret);
-    Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
-                                 << " (from " << n << ")" << std::endl;
+    bool childChanged = false;
+    for (unsigned i = 0; i < n.getNumChildren(); i++)
+    {
+      Node nc = extendedRewrite(n[i]);
+      childChanged = nc != n[i] || childChanged;
+      children.push_back(nc);
+    }
+    // Some commutative operators have rewriters that are agnostic to order,
+    // thus, we sort here.
+    if (TermUtil::isComm(n.getKind()) && (d_aggr || children.size() <= 5))
+    {
+      childChanged = true;
+      std::sort(children.begin(), children.end());
+    }
+    if (childChanged)
+    {
+      ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
+    }
+  }
+  ret = Rewriter::rewrite(ret);
+  Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
+                               << " (from " << n << ")" << std::endl;
 
-    Node new_ret;
-    if (ret.getKind() == kind::EQUAL)
+  Node new_ret;
+  if (ret.getKind() == kind::EQUAL)
+  {
+    if (new_ret.isNull())
     {
-      if (new_ret.isNull())
-      {
-        // simple ITE pulling
-        new_ret = extendedRewritePullIte(ret);
-      }
+      // simple ITE pulling
+      new_ret = extendedRewritePullIte(ret);
+    }
+  }
+  else if (ret.getKind() == kind::ITE)
+  {
+    Assert(ret[1] != ret[2]);
+    if (ret[0].getKind() == NOT)
+    {
+      ret = NodeManager::currentNM()->mkNode(
+          kind::ITE, ret[0][0], ret[2], ret[1]);
     }
-    else if (ret.getKind() == kind::ITE)
+    if (ret[0].getKind() == kind::EQUAL)
     {
-      Assert(ret[1] != ret[2]);
-      if (ret[0].getKind() == NOT)
+      // simple invariant ITE
+      for (unsigned i = 0; i < 2; i++)
       {
-        ret = NodeManager::currentNM()->mkNode(
-            kind::ITE, ret[0][0], ret[2], ret[1]);
+        if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i])
+        {
+          Trace("q-ext-rewrite")
+              << "sygus-extr : " << ret << " rewrites to " << ret[2]
+              << " due to simple invariant ITE." << std::endl;
+          new_ret = ret[2];
+          break;
+        }
       }
-      if (ret[0].getKind() == kind::EQUAL)
+      // notice this is strictly more general than the above
+      if (new_ret.isNull())
       {
-        // simple invariant ITE
+        // simple substitution
         for (unsigned i = 0; i < 2; i++)
         {
-          if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i])
+          TNode r1 = ret[0][i];
+          TNode r2 = ret[0][1 - i];
+          if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
           {
-            Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to "
-                                   << ret[2] << " due to simple invariant ITE."
-                                   << std::endl;
-            new_ret = ret[2];
-            break;
-          }
-        }
-        // notice this is strictly more general than the above
-        if (new_ret.isNull())
-        {
-          // simple substitution
-          for (unsigned i = 0; i < 2; i++)
-          {              
-            TNode r1 = ret[0][i];
-            TNode r2 = ret[0][1 - i];
-            if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
+            Node retn = ret[1].substitute(r1, r2);
+            if (retn != ret[1])
             {
-              Node retn = ret[1].substitute(r1, r2);
-              if (retn != ret[1])
-              {
-                new_ret = NodeManager::currentNM()->mkNode(
-                    kind::ITE, ret[0], retn, ret[2]);
-                Trace("q-ext-rewrite")
-                    << "sygus-extr : " << ret << " rewrites to " << new_ret
-                    << " due to simple ITE substitution." << std::endl;
-              }
+              new_ret = NodeManager::currentNM()->mkNode(
+                  kind::ITE, ret[0], retn, ret[2]);
+              Trace("q-ext-rewrite")
+                  << "sygus-extr : " << ret << " rewrites to " << new_ret
+                  << " due to simple ITE substitution." << std::endl;
             }
           }
         }
       }
     }
-    else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION
-             || ret.getKind() == INTS_MODULUS)
+  }
+  else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION
+           || ret.getKind() == INTS_MODULUS)
+  {
+    // rewrite as though total
+    std::vector<Node> children;
+    bool all_const = true;
+    for (unsigned i = 0; i < ret.getNumChildren(); i++)
     {
-      // rewrite as though total
-      std::vector<Node> children;
-      bool all_const = true;
-      for (unsigned i = 0; i < ret.getNumChildren(); i++)
+      if (ret[i].isConst())
       {
-        if (ret[i].isConst())
-        {
-          children.push_back(ret[i]);
-        }
-        else
-        {
-          all_const = false;
-          break;
-        }
+        children.push_back(ret[i]);
       }
-      if (all_const)
+      else
       {
-        Kind new_k =
-            (ret.getKind() == DIVISION
-                 ? DIVISION_TOTAL
-                 : (ret.getKind() == INTS_DIVISION ? INTS_DIVISION_TOTAL
-                                                   : INTS_MODULUS_TOTAL));
-        new_ret = NodeManager::currentNM()->mkNode(new_k, children);
-        Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to "
-                               << new_ret << " due to total interpretation."
-                               << std::endl;
+        all_const = false;
+        break;
       }
     }
-    // more expensive rewrites
-    if (new_ret.isNull())
+    if (all_const)
     {
-      Trace("q-ext-rewrite-debug2") << "Do expensive rewrites on " << ret
-                                    << std::endl;
-      bool polarity = ret.getKind() != NOT;
-      Node ret_atom = ret.getKind() == NOT ? ret[0] : ret;
-      if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
-          || ret_atom.getKind() == GEQ)
+      Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
+                                              : (ret.getKind() == INTS_DIVISION
+                                                     ? INTS_DIVISION_TOTAL
+                                                     : INTS_MODULUS_TOTAL));
+      new_ret = NodeManager::currentNM()->mkNode(new_k, children);
+      Trace("q-ext-rewrite")
+          << "sygus-extr : " << ret << " rewrites to " << new_ret
+          << " due to total interpretation." << std::endl;
+    }
+  }
+  // more expensive rewrites
+  if (new_ret.isNull() && d_aggr)
+  {
+    new_ret = extendedRewriteAggr(ret);
+  }
+
+  d_ext_rewrite_cache[n] = ret;
+  if (!new_ret.isNull())
+  {
+    ret = extendedRewrite(new_ret);
+  }
+  d_ext_rewrite_cache[n] = ret;
+  return ret;
+}
+
+Node ExtendedRewriter::extendedRewriteAggr(Node n)
+{
+  Node new_ret;
+  Trace("q-ext-rewrite-debug2")
+      << "Do aggressive rewrites on " << n << std::endl;
+  bool polarity = n.getKind() != NOT;
+  Node ret_atom = n.getKind() == NOT ? n[0] : n;
+  if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
+      || ret_atom.getKind() == GEQ)
+  {
+    Trace("q-ext-rewrite-debug2")
+        << "Compute monomial sum " << ret_atom << std::endl;
+    // compute monomial sum
+    std::map<Node, Node> msum;
+    if (ArithMSum::getMonomialSumLit(ret_atom, msum))
+    {
+      for (std::map<Node, Node>::iterator itm = msum.begin(); itm != msum.end();
+           ++itm)
       {
-        Trace("q-ext-rewrite-debug2") << "Compute monomial sum " << ret_atom
-                                      << std::endl;
-        // compute monomial sum
-        std::map<Node, Node> msum;
-        if (ArithMSum::getMonomialSumLit(ret_atom, msum))
+        Node v = itm->first;
+        Trace("q-ext-rewrite-debug2")
+            << itm->first << " * " << itm->second << std::endl;
+        if (v.getKind() == ITE)
         {
-          for (std::map<Node, Node>::iterator itm = msum.begin();
-               itm != msum.end();
-               ++itm)
+          Node veq;
+          int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
+          if (res != 0)
           {
-            Node v = itm->first;
-            Trace("q-ext-rewrite-debug2") << itm->first << " * " << itm->second
-                                          << std::endl;
-            if (v.getKind() == ITE)
+            Trace("q-ext-rewrite-debug")
+                << "  have ITE relation, solved form : " << veq << std::endl;
+            // try pulling ITE
+            new_ret = extendedRewritePullIte(veq);
+            if (!new_ret.isNull())
             {
-              Node veq;
-              int res = ArithMSum::isolate(v, msum, veq, ret_atom.getKind());
-              if (res != 0)
+              if (!polarity)
               {
-                Trace("q-ext-rewrite-debug")
-                    << "  have ITE relation, solved form : " << veq
-                    << std::endl;
-                // try pulling ITE
-                new_ret = extendedRewritePullIte(veq);
-                if (!new_ret.isNull())
-                {
-                  if (!polarity)
-                  {
-                    new_ret = new_ret.negate();
-                  }
-                  break;
-                }
-              }
-              else
-              {
-                Trace("q-ext-rewrite-debug") << "  failed to isolate " << v
-                                             << " in " << ret << std::endl;
+                new_ret = new_ret.negate();
               }
+              break;
             }
           }
-        }
-        else
-        {
-          Trace("q-ext-rewrite-debug") << "  failed to get monomial sum of "
-                                       << ret << std::endl;
+          else
+          {
+            Trace("q-ext-rewrite-debug")
+                << "  failed to isolate " << v << " in " << n << std::endl;
+          }
         }
       }
-      else if (ret_atom.getKind() == ITE)
-      {
-        // TODO : conditional rewriting
-      }
-      else if (ret.getKind() == kind::AND || ret.getKind() == kind::OR)
-      {
-        // TODO condition merging
-      }
     }
-
-    if (!new_ret.isNull())
+    else
     {
-      ret = Rewriter::rewrite(new_ret);
+      Trace("q-ext-rewrite-debug")
+          << "  failed to get monomial sum of " << n << std::endl;
     }
-    d_ext_rewrite_cache[n] = ret;
-    return ret;
-  }
-  else
-  {
-    return it->second;
   }
+  // TODO (#1599) : conditional rewriting, condition merging
+  return new_ret;
 }
 
 } /* CVC4::theory::quantifiers namespace */
index 3a9fdb918291195a47f781eb982746eef5eef53b..25d710a6b05c2e661cef51d55ceb7e0f66cc7962 100644 (file)
@@ -43,12 +43,23 @@ namespace quantifiers {
 class ExtendedRewriter
 {
  public:
-  ExtendedRewriter();
+  ExtendedRewriter(bool aggr = true);
   ~ExtendedRewriter() {}
   /** return the extended rewritten form of n */
   Node extendedRewrite(Node n);
 
  private:
+  /**
+   * Whether this extended rewriter applies aggressive rewriting techniques,
+   * which are more expensive. Examples of aggressive rewriting include:
+   * - conditional rewriting,
+   * - condition merging,
+   * - sorting childing of commutative operators with more than 5 children.
+   *
+   * Aggressive rewriting is applied for SyGuS, whereas non-aggressive rewriting
+   * may be applied as a preprocessing step.
+   */
+  bool d_aggr;
   /** true and false nodes */
   Node d_true;
   Node d_false;
@@ -62,6 +73,8 @@ class ExtendedRewriter
    * where ---->^E denotes extended rewriting.
    */
   Node extendedRewritePullIte(Node n);
+  /** extended rewrite aggressive */
+  Node extendedRewriteAggr(Node n);
 };
 
 } /* CVC4::theory::quantifiers namespace */
index cda652ee7deee0dc85448f5318113469966a23bb..4c80108f15266b2a9fdef7823745445ae2c90a5b 100644 (file)
@@ -33,7 +33,7 @@ namespace quantifiers {
 TermDbSygus::TermDbSygus(context::Context* c, QuantifiersEngine* qe)
     : d_quantEngine(qe),
       d_syexp(new SygusExplain(this)),
-      d_ext_rw(new ExtendedRewriter)
+      d_ext_rw(new ExtendedRewriter(true))
 {
   d_true = NodeManager::currentNM()->mkConst( true );
   d_false = NodeManager::currentNM()->mkConst( false );