Refactor extended rewriter, move rewrites to aggressive (#2387)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 28 Aug 2018 01:08:01 +0000 (20:08 -0500)
committerAndres Noetzli <andres.noetzli@gmail.com>
Tue, 28 Aug 2018 01:08:01 +0000 (18:08 -0700)
This is work towards #2305.

With this PR, CVC4's performance is fairly reasonable on the Kind2 BMC benchmarks with --decision=internal --ext-rew-prep --ext-rew-prep-agg.

src/theory/quantifiers/extended_rewrite.cpp
src/theory/quantifiers/extended_rewrite.h
test/regress/regress0/fp/ext-rew-test.smt2

index cdd597a5cc1290ad47cfac411b787c168457d364..21a7fe29cefb1f00836c86fe392376abc71ccc15 100644 (file)
@@ -35,7 +35,10 @@ typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
 
 ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
 {
+  d_true = NodeManager::currentNM()->mkConst(true);
+  d_false = NodeManager::currentNM()->mkConst(false);
 }
+
 void ExtendedRewriter::setCache(Node n, Node ret)
 {
   ExtRewriteAttribute era;
@@ -74,30 +77,33 @@ Node ExtendedRewriter::extendedRewrite(Node n)
   NodeManager* nm = NodeManager::currentNM();
 
   //--------------------pre-rewrite
-  Node pre_new_ret;
-  if (ret.getKind() == IMPLIES)
-  {
-    pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
-    debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
-  }
-  else if (ret.getKind() == XOR)
-  {
-    pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
-    debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
-  }
-  else if (ret.getKind() == NOT)
+  if (d_aggr)
   {
-    pre_new_ret = extendedRewriteNnf(ret);
-    debugExtendedRewrite(ret, pre_new_ret, "NNF");
-  }
-  if (!pre_new_ret.isNull())
-  {
-    ret = extendedRewrite(pre_new_ret);
+    Node pre_new_ret;
+    if (ret.getKind() == IMPLIES)
+    {
+      pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
+      debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
+    }
+    else if (ret.getKind() == XOR)
+    {
+      pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
+      debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
+    }
+    else if (ret.getKind() == NOT)
+    {
+      pre_new_ret = extendedRewriteNnf(ret);
+      debugExtendedRewrite(ret, pre_new_ret, "NNF");
+    }
+    if (!pre_new_ret.isNull())
+    {
+      ret = extendedRewrite(pre_new_ret);
 
-    Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> "
-                                 << pre_new_ret << std::endl;
-    setCache(n, ret);
-    return ret;
+      Trace("q-ext-rewrite-debug")
+          << "...ext-pre-rewrite : " << n << " -> " << pre_new_ret << std::endl;
+      setCache(n, ret);
+      return ret;
+    }
   }
   //--------------------end pre-rewrite
 
@@ -175,6 +181,7 @@ Node ExtendedRewriter::extendedRewrite(Node n)
     new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
     debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
   }
+  Assert(new_ret.isNull() || new_ret != ret);
   if (new_ret.isNull() && ret.getKind() != ITE)
   {
     // simple ITE pulling
@@ -217,6 +224,13 @@ Node ExtendedRewriter::extendedRewrite(Node n)
   }
   Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
                                << std::endl;
+  if (Trace.isOn("q-ext-rewrite-nf"))
+  {
+    if (n == ret)
+    {
+      Trace("q-ext-rewrite-nf") << "ext-rew normal form : " << n << std::endl;
+    }
+  }
   setCache(n, ret);
   return ret;
 }
@@ -388,6 +402,35 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
       break;
     }
   }
+  if (new_ret.isNull())
+  {
+    // merging branches
+    for (unsigned i = 1; i <= 2; i++)
+    {
+      if (n[i].getKind() == ITE)
+      {
+        Node no = n[3 - i];
+        for (unsigned j = 1; j <= 2; j++)
+        {
+          if (n[i][j] == no)
+          {
+            // e.g.
+            // ite( C1, ite( C2, t1, t2 ), t1 ) ----> ite( C1 ^ ~C2, t2, t1 )
+            Node nc1 = i == 2 ? n[0].negate() : n[0];
+            Node nc2 = j == 1 ? n[i][0].negate() : n[i][0];
+            Node new_cond = nm->mkNode(AND, nc1, nc2);
+            new_ret = nm->mkNode(ITE, new_cond, n[i][3 - j], no);
+            ss_reason << "ITE merge branch";
+            break;
+          }
+        }
+      }
+      if (!new_ret.isNull())
+      {
+        break;
+      }
+    }
+  }
 
   if (new_ret.isNull() && d_aggr)
   {
@@ -395,10 +438,7 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
     // substitution to the children of ite( x = t ^ C, s, t ) below.
     std::vector<Node> vars;
     std::vector<Node> subs;
-    for (const Node& eq : eq_conds)
-    {
-      inferSubstitution(eq, vars, subs);
-    }
+    inferSubstitution(n[0], vars, subs, true);
 
     if (!vars.empty())
     {
@@ -437,6 +477,27 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
         }
       }
     }
+    if (new_ret.isNull())
+    {
+      // ite( C, t, s ) ----> ite( C, t, s { C -> false } )
+      TNode tv = n[0];
+      TNode ts = d_false;
+      Node nn = t2.substitute(tv, ts);
+      if (nn != t2)
+      {
+        nn = Rewriter::rewrite(nn);
+        if (nn == t1)
+        {
+          new_ret = nn;
+          ss_reason << "ITE subs invariant false";
+        }
+        else if (full || nn.isConst())
+        {
+          new_ret = nm->mkNode(itek, n[0], t1, nn);
+          ss_reason << "ITE subs false";
+        }
+      }
+    }
   }
 
   // only print debug trace if full=true
@@ -450,6 +511,11 @@ Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
 
 Node ExtendedRewriter::extendedRewriteAndOr(Node n)
 {
+  // all the below rewrites are aggressive
+  if (!d_aggr)
+  {
+    return Node::null();
+  }
   Node new_ret;
   // all kinds are legal to substitute over : hence we give the empty map
   std::map<Kind, bool> bcp_kinds;
@@ -475,6 +541,7 @@ Node ExtendedRewriter::extendedRewriteAndOr(Node n)
 
 Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
 {
+  Assert(n.getKind() != ITE);
   NodeManager* nm = NodeManager::currentNM();
   TypeNode tn = n.getType();
   std::vector<Node> children;
@@ -491,7 +558,9 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
   std::map<unsigned, std::map<unsigned, Node> > ite_c;
   for (unsigned i = 0; i < nchildren; i++)
   {
-    if (n[i].getKind() == itek)
+    // only pull ITEs apart if we are aggressive
+    if (n[i].getKind() == itek
+        && (d_aggr || (n[i][1].getKind() != ITE && n[i][2].getKind() != ITE)))
     {
       unsigned ii = hasOp ? i + 1 : i;
       for (unsigned j = 0; j < 2; j++)
@@ -510,8 +579,17 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
         debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
         return ite_c[i][0];
       }
-      else if (d_aggr)
+      if (d_aggr)
       {
+        if (nchildren == 2 && (n[1 - i].isVar() || n[1 - i].isConst())
+            && !n[1 - i].getType().isBoolean() && tn.isBoolean())
+        {
+          // always pull variable or constant with binary (theory) predicate
+          // e.g. P( x, ite( A, t1, t2 ) ) ---> ite( A, P( x, t1 ), P( x, t2 ) )
+          Node new_ret = nm->mkNode(ITE, n[i][0], ite_c[i][0], ite_c[i][1]);
+          debugExtendedRewrite(n, new_ret, "ITE pull var predicate");
+          return new_ret;
+        }
         for (unsigned j = 0; j < 2; j++)
         {
           Node pullr = ite_c[i][j];
@@ -522,7 +600,7 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
             // implies
             // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
             Node new_ret;
-            if (tn.isBoolean())
+            if (tn.isBoolean() && pullr.isConst())
             {
               // remove false/true child immediately
               bool pol = pullr.getConst<bool>();
@@ -544,33 +622,35 @@ Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
       }
     }
   }
-
-  for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
+  if (d_aggr)
   {
-    Node nite = n[ip.first];
-    Assert(nite.getKind() == itek);
-    // now, simply pull the ITE and try ITE rewrites
-    Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
-    pull_ite = Rewriter::rewrite(pull_ite);
-    if (pull_ite.getKind() == ITE)
+    for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
     {
-      Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
-      if (!new_pull_ite.isNull())
+      Node nite = n[ip.first];
+      Assert(nite.getKind() == itek);
+      // now, simply pull the ITE and try ITE rewrites
+      Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
+      pull_ite = Rewriter::rewrite(pull_ite);
+      if (pull_ite.getKind() == ITE)
       {
-        debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
-        return new_pull_ite;
+        Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
+        if (!new_pull_ite.isNull())
+        {
+          debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
+          return new_pull_ite;
+        }
+      }
+      else
+      {
+        // A general rewrite could eliminate the ITE by pulling.
+        // An example is:
+        //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
+        //   ite( C, ~~x, ite( C, y, x ) ) --->
+        //   x
+        // where ~ is bitvector negation.
+        debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
+        return pull_ite;
       }
-    }
-    else
-    {
-      // A general rewrite could eliminate the ITE by pulling.
-      // An example is:
-      //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
-      //   ite( C, ~~x, ite( C, y, x ) ) --->
-      //   x
-      // where ~ is bitvector negation.
-      debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
-      return pull_ite;
     }
   }
 
@@ -1475,8 +1555,19 @@ Node ExtendedRewriter::solveEquality(Node n)
 
 bool ExtendedRewriter::inferSubstitution(Node n,
                                          std::vector<Node>& vars,
-                                         std::vector<Node>& subs)
+                                         std::vector<Node>& subs,
+                                         bool usePred)
 {
+  if (n.getKind() == AND)
+  {
+    bool ret = false;
+    for (const Node& nc : n)
+    {
+      bool cret = inferSubstitution(nc, vars, subs, usePred);
+      ret = ret || cret;
+    }
+    return ret;
+  }
   if (n.getKind() == EQUAL)
   {
     // see if it can be put into form x = y
@@ -1525,6 +1616,13 @@ bool ExtendedRewriter::inferSubstitution(Node n,
       }
     }
   }
+  if (usePred)
+  {
+    bool negated = n.getKind() == NOT;
+    vars.push_back(negated ? n[0] : n);
+    subs.push_back(negated ? d_false : d_true);
+    return true;
+  }
   return false;
 }
 
index 4d3f08b1d2995a8f9e8e9cd653a9b47660b861bb..29f3b7bb30e66c947a570a154db3b32e6589f940 100644 (file)
@@ -64,6 +64,9 @@ class ExtendedRewriter
    * may be applied as a preprocessing step.
    */
   bool d_aggr;
+  /** true/false nodes */
+  Node d_true;
+  Node d_false;
   /** cache that the extended rewritten form of n is ret */
   void setCache(Node n, Node ret);
   /** add to children
@@ -213,10 +216,13 @@ class ExtendedRewriter
    * (2) a variable y such that x < y based on an ordering,
    * then this method adds x to vars and y to subs and return true, otherwise
    * it returns false.
+   * If usePred is true, we may additionally add n -> true, or n[0] -> false
+   * is n is a negation.
    */
   bool inferSubstitution(Node n,
                          std::vector<Node>& vars,
-                         std::vector<Node>& subs);
+                         std::vector<Node>& subs,
+                         bool usePred = false);
   /** extended rewrite
    *
    * Prints debug information, indicating the rewrite n ---> ret was found.
index 785c654ef6c9051aae32e8078e29da7d97f50b7e..3fb3a9e53348c607f6f825fc4729a148e1bf92ab 100644 (file)
@@ -1,4 +1,4 @@
-; COMMAND-LINE: --ext-rew-prep
+; COMMAND-LINE: --ext-rew-prep --ext-rew-prep-agg
 ; EXPECT: unsat
 (set-info :smt-lib-version 2.6)
 (set-logic QF_FP)