Use simpler EQUAL SCs for LSHR0, LSHR1, ASHR0, AHSR1, SHL0, SHL1. (#1482)
authorMathias Preiner <mathias.preiner@gmail.com>
Sat, 6 Jan 2018 02:05:41 +0000 (18:05 -0800)
committerGitHub <noreply@github.com>
Sat, 6 Jan 2018 02:05:41 +0000 (18:05 -0800)
src/theory/quantifiers/bv_inverter.cpp

index 7f2343df706abaaf9a79e325119502cc5b74c3aa..7407051be651b819fad416e1d4324c3a9d45b6c8 100644 (file)
@@ -229,6 +229,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t)
     else
     {
       /* x >= t
+       * with side condition:
        * true (no side condition)  */
       sc = nm->mkNode(NOT, nm->mkNode(k, x, t));
     }
@@ -250,6 +251,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t)
     else
     {
       /* x <= t
+       * with side condition:
        * true (no side condition)  */
       sc = nm->mkNode(NOT, nm->mkNode(k, x, t));
     }
@@ -283,6 +285,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
     else
     {
       /* x >= t
+       * with side condition:
        * true (no side condition)  */
       sc = nm->mkNode(NOT, nm->mkNode(k, x, t));
     }
@@ -306,6 +309,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
     else
     {
       /* x <= t
+       * with side condition:
        * true (no side condition)  */
       sc = nm->mkNode(NOT, nm->mkNode(k, x, t));
     }
@@ -1343,6 +1347,29 @@ static Node getScBvAndOr(bool pol,
   return sc;
 }
 
+static Node defaultShiftSc(Kind litk, Kind shk, Node s, Node t)
+{
+  unsigned w;
+  NodeBuilder<> nb(OR);
+  NodeManager *nm;
+
+  nm = NodeManager::currentNM();
+
+  w = bv::utils::getSize(s);
+  Assert(w == bv::utils::getSize(t));
+
+  nb << nm->mkNode(litk, s, t);
+  for (unsigned i = 1; i <= w; i++)
+  {
+    Node sw = bv::utils::mkConst(w, i);
+    nb << nm->mkNode(litk, nm->mkNode(shk, s, sw), t);
+  }
+  if (nb.getNumChildren() == 1)
+    return nb[0];
+  return nb.constructNode();
+}
+
+
 static Node getScBvLshr(bool pol,
                         Kind litk,
                         Kind k,
@@ -1370,29 +1397,11 @@ static Node getScBvLshr(bool pol,
       if (pol)
       {
         /* x >> s = t
-         * with side condition:
-         * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0)
-         * ->
-         * (or
-         *  (= s z)
-         *  (and
-         *   (bvult s w)
-         *   (= ((_ extract 2*w-1 w) (bvshl (concat z t) (concat z s))) z))
-         *  (and (bvuge s w) (= t z)))
-         * with w = getSize(t) = getSize(s)
-         * and z = 0 with getSize(z) = w  */
-        Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t);
-        Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s);
-        Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
-        Node ext = bv::utils::mkExtract(shl, 2*w-1, w);
-
-        Node o1 = s.eqNode(z);
-        Node o2 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
-        Node o3 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
-
-        scl = nm->mkNode(OR, o1, o2, o3);
+         * with side condition (synthesized):
+         * (= (bvlshr (bvshl t s) s) t)  */
+        Node shl = nm->mkNode(BITVECTOR_SHL, t, s);
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, shl, s);
+        scl = lshr.eqNode(t);
       }
       else
       {
@@ -1413,30 +1422,9 @@ static Node getScBvLshr(bool pol,
       {
         /* s >> x = t
          * with side condition:
-         * (or
-         *  (= t z)
-         *  (= s t)
-         *  (and
-         *   (= ((_ extract w-1-i 0) t) ((_ extract w-1 i) s))
-         *   (= ((_ extract w-1 w-i) t) z_i)))
-         * for 0 < i < w
-         * where
-         * w = getSize(s) = getSize(t)
-         * and z_i = 0 with getSize(z_i) = i  */
-        NodeBuilder<> nb(nm, OR);
-        nb << nm->mkNode(EQUAL, t, s);
-        for (unsigned i = 1; i < w; ++i)
-        {
-          nb << nm->mkNode(AND,
-              nm->mkNode(EQUAL,
-                bv::utils::mkExtract(t, w - 1 - i, 0),
-                bv::utils::mkExtract(s, w - 1, i)),
-              nm->mkNode(EQUAL,
-                bv::utils::mkExtract(t, w - 1, w - i),
-                bv::utils::mkZero(i)));
-        }
-        nb << t.eqNode(z);
-        scl = nb.constructNode();
+         * (or (= (bvlshr s i) t) ...)
+         * for i in 0..w  */
+        scl = defaultShiftSc(EQUAL, BITVECTOR_LSHR, s, t);
       }
       else
       {
@@ -1681,41 +1669,23 @@ static Node getScBvAshr(bool pol,
       {
         /* x >> s = t
          * with side condition:
-         * (or
-         *  (= s z)
-         *  (and
-         *   (bvult s w)
-         *   (or
-         *    (= ((_ extract 2*w-1 w-1) (bvshl (concat z t) (concat z s))) z)
-         *    (= ((_ extract 2*w-1 w-1)
-         *        (bvshl (concat ones t) (concat z s))) ones)))
-         *  (and (bvuge s w) (or (= t z) (= t ones))))
+         * (and
+         *  (=> (bvult s w) (= (bvashr (bvshl t s) s) t))
+         *  (=> (bvuge s w) (or (= t ones) (= t z)))
+         * )
          * where
          * z = 0 with getSize(z) = w
          * and ones = ~0 with getSize(ones) = w
          * and w = getSize(t) = getSize(s)  */
-        Node zz = bv::utils::mkZero(w+1);
-        Node nn = bv::utils::mkOnes(w+1);
         Node ww = bv::utils::mkConst(w, w);
-
-        Node z_o_t = bv::utils::mkConcat(z, t);
-        Node z_o_s = bv::utils::mkConcat(z, s);
-        Node n_o_t = bv::utils::mkConcat(n, t);
-
-        Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
-        Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s);
-        Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1);
-        Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1);
-
-        Node o1 = s.eqNode(z);
-        Node o2 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_ULT, s, ww),
-            nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn)));
-        Node o3 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_UGE, s, ww),
-            nm->mkNode(OR, t.eqNode(z), t.eqNode(n)));
-
-        scl = nm->mkNode(OR, o1, o2, o3);
+        Node shl = nm->mkNode(BITVECTOR_SHL, t, s);
+        Node ashr = nm->mkNode(BITVECTOR_ASHR, shl, s);
+        Node ult = nm->mkNode(BITVECTOR_ULT, s, ww);
+        Node imp1 = ult.impNode(ashr.eqNode(t));
+        Node to = t.eqNode(n);
+        Node tz = t.eqNode(z);
+        Node imp2 = ult.notNode().impNode(to.orNode(tz));
+        scl = imp1.andNode(imp2);
       }
       else
       {
@@ -1730,48 +1700,9 @@ static Node getScBvAshr(bool pol,
       {
         /* s >> x = t
          * with side condition:
-         * (or
-         *  (and (= ((_ extract w-1 w-1) s) (_ bv0 1)) (= t z))
-         *  (and (= ((_ extract w-1 w-1) s) (_ bv0 1)) (= t ones))
-         *  (= s t)
-         *  (and
-         *   (= ((_ extract w-1-i 0) t) ((_ extract w-1 i) s))
-         *   (or
-         *    (and
-         *     (= ((_ extract w-1 w-1) s) (_ bv0 1))
-         *     (= ((_ extract w-1 w-i) t) (_ bv0 i)))
-         *    (and
-         *     (= ((_ extract w-1 w-1) s) (_ bv1 1))
-         *     (= ((_ extract w-1 w-i) t) ones_i)))))
-         * for 0 < i < w
-         * where
-         * z = 0 and getSize(z) = w
-         * and ones = ~0 and getSize(ones) = w
-         * and ones_i = ~0 and getSize(ones_i) = i
-         * and w = getSize(s) = getSize(t)  */
-        Node msbz = bv::utils::mkExtract(
-            s, w-1, w-1).eqNode(bv::utils::mkZero(1));
-        Node msbn = bv::utils::mkExtract(
-            s, w-1, w-1).eqNode(bv::utils::mkOnes(1));
-        NodeBuilder<> nb(nm, OR);
-        nb << nm->mkNode(EQUAL, t, s);
-        for (unsigned i = 1; i < w; ++i)
-        {
-          Node ext = bv::utils::mkExtract(t, w-1, w-i);
-
-          Node o1 = nm->mkNode(AND, msbz, ext.eqNode(bv::utils::mkZero(i)));
-          Node o2 = nm->mkNode(AND, msbn, ext.eqNode(bv::utils::mkOnes(i)));
-          Node o = nm->mkNode(OR, o1, o2);
-
-          Node e = nm->mkNode(EQUAL,
-                              bv::utils::mkExtract(t, w - 1 - i, 0),
-                              bv::utils::mkExtract(s, w - 1, i));
-
-          nb << nm->mkNode(AND, e, o);
-        }
-        nb << nm->mkNode(AND, msbz, t.eqNode(z));
-        nb << nm->mkNode(AND, msbn, t.eqNode(n));
-        scl = nb.constructNode();
+         * (or (= (bvashr s i) t) ...)
+         * for i in 0..w  */
+        scl = defaultShiftSc(EQUAL, BITVECTOR_ASHR, s, t);
       }
       else
       {
@@ -1998,28 +1929,6 @@ static Node getScBvAshr(bool pol,
   return sc;
 }
 
-static Node naiveShlSc1(Kind litk, Node s, Node t)
-{
-  unsigned w;
-  NodeBuilder<> nb(OR);
-  NodeManager *nm;
-
-  nm = NodeManager::currentNM();
-
-  w = bv::utils::getSize(s);
-  Assert(w == bv::utils::getSize(t));
-
-  nb << nm->mkNode(litk, s, t);
-  for (unsigned i = 1; i <= w; i++)
-  {
-    Node sw = bv::utils::mkConst(w, i);
-    nb << nm->mkNode(litk, nm->mkNode(BITVECTOR_SHL, s, sw), t);
-  }
-  if (nb.getNumChildren() == 1)
-    return nb[0];
-  return nb.constructNode();
-}
-
 static Node getScBvShl(bool pol,
                        Kind litk,
                        Kind k,
@@ -2047,30 +1956,11 @@ static Node getScBvShl(bool pol,
       if (pol)
       {
         /* x << s = t
-         * with side condition:
-         * (s = 0 || ctz(t) >= s)
-         * <->
-         * (or
-         *  (= s z)
-         *  (and
-         *   (bvult s w)
-         *   (= ((_ extract w-1 0) (bvlshr (concat t z) (concat z s))) z))
-         *  (and (bvuge s w) (= t z)))
-         *
-         * where
-         * w = getSize(s) = getSize(t) = getSize(z) and z = 0  */
-        Node shr = nm->mkNode(BITVECTOR_LSHR,
-            bv::utils::mkConcat(t, z),
-            bv::utils::mkConcat(z, s));
-        Node ext = bv::utils::mkExtract(shr, w - 1, 0);
-
-        Node o1 = nm->mkNode(EQUAL, s, z);
-        Node o2 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
-        Node o3 = nm->mkNode(AND,
-            nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
-
-        scl = nm->mkNode(OR, o1, o2, o3);
+         * with side condition (synthesized):
+         * (= (bvshl (bvlshr t s) s) t)  */
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, t, s);
+        Node shl = nm->mkNode(BITVECTOR_SHL, lshr, s);
+        scl = shl.eqNode(t);
       }
       else
       {
@@ -2091,29 +1981,9 @@ static Node getScBvShl(bool pol,
       {
         /* s << x = t
          * with side condition:
-         * (or
-         *  (= t z)
-         *  (= s z)
-         *  (and
-         *   (= ((_ extract w-1 i) t) ((_ extract w-1-i 0)))
-         *   (= ((_ extract i-1 0) t) z_i)))
-         * for 0 < i < w
-         * where
-         * w = getSize(s) = getSize(t)
-         * and z = 0 with getSize(z) = w
-         * and z_i = 0 with getSize(z_i) = i  */
-        NodeBuilder<> nb(nm, OR);
-        nb << nm->mkNode(EQUAL, t, s);
-        for (unsigned i = 1; i < w; ++i)
-        {
-          nb << nm->mkNode(AND,
-              nm->mkNode(EQUAL,
-                bv::utils::mkExtract(t, w-1, i), bv::utils::mkExtract(s, w-1-i, 0)),
-              nm->mkNode(EQUAL,
-                bv::utils::mkExtract(t, i-1, 0), bv::utils::mkZero(i)));
-        }
-        nb << t.eqNode(z);
-        scl = nb.constructNode();
+         * (or (= (bvshl s i) t) ...)
+         * for i in 0..w  */
+        scl = defaultShiftSc(EQUAL, BITVECTOR_SHL, s, t);
       }
       else
       {
@@ -2160,8 +2030,8 @@ static Node getScBvShl(bool pol,
         /* s << x >= t
          * with side condition:
          * (or (bvuge (bvshl s i) t) ...)
-         * for i in 0..w-1  */
-        scl = naiveShlSc1(BITVECTOR_UGE, s, t);
+         * for i in 0..w  */
+        scl = defaultShiftSc(BITVECTOR_UGE, BITVECTOR_SHL, s, t);
       }
     }
   }
@@ -2194,8 +2064,8 @@ static Node getScBvShl(bool pol,
         /* s << x > t
          * with side condition:
          * (or (bvugt (bvshl s i) t) ...)
-         * for i in 0..w-1  */
-        scl = naiveShlSc1(BITVECTOR_UGT, s, t);
+         * for i in 0..w  */
+        scl = defaultShiftSc(BITVECTOR_UGT, BITVECTOR_SHL, s, t);
       }
       else
       {
@@ -2257,8 +2127,8 @@ static Node getScBvShl(bool pol,
         /* s << x >= t
          * with side condition:
          * (or (bvsge (bvshl s i) t) ...)
-         * for i in 0..w-1  */
-        scl = naiveShlSc1(BITVECTOR_SGE, s, t);
+         * for i in 0..w  */
+        scl = defaultShiftSc(BITVECTOR_SGE, BITVECTOR_SHL, s, t);
       }
     }
   }
@@ -2300,8 +2170,8 @@ static Node getScBvShl(bool pol,
         /* s << x > t
          * with side condition:
          * (or (bvsgt (bvshl s i) t) ...)
-         * for i in 0..w-1  */
-        scl = naiveShlSc1(BITVECTOR_SGT, s, t);
+         * for i in 0..w  */
+        scl = defaultShiftSc(BITVECTOR_SGT, BITVECTOR_SHL, s, t);
       }
       else
       {