Add special {SGE,SGT,NE}_UDIV1 side conditions for BV of size 1. (#1483)
authorMathias Preiner <mathias.preiner@gmail.com>
Sat, 6 Jan 2018 06:50:21 +0000 (22:50 -0800)
committerGitHub <noreply@github.com>
Sat, 6 Jan 2018 06:50:21 +0000 (22:50 -0800)
This commit further fixes some other issues with bit-vectors of size 1.

src/theory/bv/theory_bv_utils.h
src/theory/quantifiers/bv_inverter.cpp

index ed772b7c4da4509a04d738b122db9ce394da3e6a..a11436c9e8d15b47be2c49fade78171e6f08ce1f 100644 (file)
@@ -195,6 +195,18 @@ inline BitVector mkBitVectorOnes(unsigned size) {
   return BitVector(1, Integer(1)).signExtend(size - 1); 
 }
 
+inline BitVector mkBitVectorMinSigned(unsigned size)
+{
+  Assert(size > 0);
+  return BitVector(size).setBit(size - 1);
+}
+
+inline BitVector mkBitVectorMaxSigned(unsigned size)
+{
+  Assert(size > 0);
+  return ~mkBitVectorMinSigned(size);
+}
+
 inline Node mkOnes(unsigned size) {
   BitVector val = mkBitVectorOnes(size); 
   return NodeManager::currentNM()->mkConst<BitVector>(val); 
index 7407051be651b819fad416e1d4324c3a9d45b6c8..3ad99999cce685911299458197540d6b7cf996a1 100644 (file)
@@ -277,7 +277,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
        * (distinct t min)
        * where
        * min is the minimum signed value with getSize(min) = w  */
-      Node min = bv::utils::mkConst(BitVector(w).setBit(w - 1));
+      Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
       Node scl = nm->mkNode(DISTINCT, min, t);
       Node scr = nm->mkNode(k, x, t);
       sc = nm->mkNode(IMPLIES, scl, scr);
@@ -300,8 +300,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
        * (distinct t max)
        * where
        * max is the signed maximum value with getSize(max) = w  */
-      BitVector bv = BitVector(w).setBit(w - 1);
-      Node max = bv::utils::mkConst(~bv);
+      Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
       Node scl = nm->mkNode(DISTINCT, t, max);
       Node scr = nm->mkNode(k, x, t);
       sc = nm->mkNode(IMPLIES, scl, scr);
@@ -425,8 +424,7 @@ static Node getScBvMult(bool pol,
        * (bvsge (bvand (bvor (bvneg s) s) max) t)
        * where
        * max is the signed maximum value with getSize(max) = w  */
-      BitVector bv = BitVector(w).setBit(w - 1);
-      Node max = bv::utils::mkConst(~bv);
+      Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
       Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s);
       Node a = nm->mkNode(BITVECTOR_AND, o, max);
       scl = nm->mkNode(BITVECTOR_SGE, a, t);
@@ -706,9 +704,7 @@ static Node getScBvUrem(bool pol,
          * where
          * z = 0 with getSize(z) = w
          * and max is the maximum signed value with getSize(max) = w  */
-        BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node nt = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, s));
         Node i1 = nm->mkNode(IMPLIES,
             nm->mkNode(BITVECTOR_SGT, s, z), nm->mkNode(BITVECTOR_SLT, t, nt));
@@ -759,8 +755,7 @@ static Node getScBvUrem(bool pol,
          * (or (bvult t min) (bvsge t s))
          * where
          * min is the minimum signed value with getSize(min) = w  */
-        BitVector bv_min = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node o1 = nm->mkNode(BITVECTOR_ULT, t, min);
         Node o2 = nm->mkNode(BITVECTOR_SGE, t, s);
         scl = nm->mkNode(OR, o1, o2);
@@ -859,8 +854,22 @@ static Node getScBvUdiv(bool pol,
       else
       {
         /* s udiv x != t
-         * true (no side condition)  */
-        scl = nm->mkConst<bool>(true);
+         * with side condition (w > 1):
+         * true (no side condition)
+         *
+         * with side condition (w == 1):
+         * (= (bvand s t) z)
+         *
+         * where
+         * z = 0 with getSize(z) = w  */
+        if (w > 1)
+        {
+          scl = nm->mkConst<bool>(true);
+        }
+        else
+        {
+          scl = nm->mkNode(BITVECTOR_AND, s, t).eqNode(z);
+        }
       }
     }
   }
@@ -972,8 +981,7 @@ static Node getScBvUdiv(bool pol,
          * where
          * z = 0 with getSize(z) = w
          * and min is the minimum signed value with getSize(min) = w  */
-        BitVector bv_min = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node sle = nm->mkNode(BITVECTOR_SLE, t, z);
         Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, min, s);
         Node slt = nm->mkNode(BITVECTOR_SLT, div, t);
@@ -989,9 +997,7 @@ static Node getScBvUdiv(bool pol,
          * where
          * ones = ~0 with getSize(ones) = w
          * and max is the maximum signed value with getSize(max) = w  */
-        BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node ones = bv::utils::mkOnes(w);
         Node udiv1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s);
         Node udiv2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, max, s);
@@ -1016,19 +1022,32 @@ static Node getScBvUdiv(bool pol,
       else
       {
         /* s udiv x >= t
-         * with side condition:
+         * with side condition (w > 1):
          * (and
          *   (=> (bvsge s z) (bvsge s t))
-         *   (=> (bvslt s z) (bvsge (bvudiv s (_ bv2 w)) t)))
+         *   (=> (bvslt s z) (bvsge (bvlshr s (_ bv1 w)) t)))
+         *
+         * with side condition (w == 1):
+         * (bvsge s t)
+         *
          * where
          * z = 0 with getSize(z) = w  */
-        Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL,
-            s, bv::utils::mkConst(w, 2));
-        Node i1 = nm->mkNode(IMPLIES,
-            nm->mkNode(BITVECTOR_SGE, s, z), nm->mkNode(BITVECTOR_SGE, s, t));
-        Node i2 = nm->mkNode(IMPLIES,
-            nm->mkNode(BITVECTOR_SLT, s, z), nm->mkNode(BITVECTOR_SGE, div, t));
-        scl = nm->mkNode(AND, i1, i2);
+
+        if (w > 1)
+        {
+          Node div = nm->mkNode(BITVECTOR_LSHR, s, bv::utils::mkConst(w, 1));
+          Node i1 = nm->mkNode(IMPLIES,
+                               nm->mkNode(BITVECTOR_SGE, s, z),
+                               nm->mkNode(BITVECTOR_SGE, s, t));
+          Node i2 = nm->mkNode(IMPLIES,
+                               nm->mkNode(BITVECTOR_SLT, s, z),
+                               nm->mkNode(BITVECTOR_SGE, div, t));
+          scl = nm->mkNode(AND, i1, i2);
+        }
+        else
+        {
+          scl = nm->mkNode(BITVECTOR_SGE, s, t);
+        }
       }
     }
   }
@@ -1046,9 +1065,7 @@ static Node getScBvUdiv(bool pol,
          * where
          * ones = ~0 with getSize(ones) = w
          * and max is the maximum signed value with getSize(max) = w  */
-        BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node ones = bv::utils::mkOnes(w);
         Node div1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s);
         Node sgt1 = nm->mkNode(BITVECTOR_SGT, div1, t);
@@ -1069,8 +1086,7 @@ static Node getScBvUdiv(bool pol,
         Node mul = nm->mkNode(BITVECTOR_MULT, s, t);
         Node div1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, s);
         Node o1 = nm->mkNode(EQUAL, div1, t);
-        BitVector bv_min = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node sle = nm->mkNode(BITVECTOR_SLE, t, z);
         Node div2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, min, s);
         Node slt = nm->mkNode(BITVECTOR_SLT, div2, t);
@@ -1083,19 +1099,31 @@ static Node getScBvUdiv(bool pol,
       if (pol)
       {
         /* s udiv x > t
-         * with side condition:
+         * with side condition (w > 1):
          * (and
          *   (=> (bvsge s z) (bvsgt s t))
-         *   (=> (bvslt s z) (bvsgt (bvudiv s (_ bv2 w)) t)))
+         *   (=> (bvslt s z) (bvsgt (bvlshr s (_ bv1 w)) t)))
+         *
+         * with side condition (w == 1):
+         * (bvsgt s t)
+         *
          * where
          * z = 0 with getSize(z) = w  */
-        Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL,
-            s, bv::utils::mkConst(w, 2));
-        Node i1 = nm->mkNode(IMPLIES,
-            nm->mkNode(BITVECTOR_SGE, s, z), nm->mkNode(BITVECTOR_SGT, s, t));
-        Node i2 = nm->mkNode(IMPLIES,
-            nm->mkNode(BITVECTOR_SLT, s, z), nm->mkNode(BITVECTOR_SGT, div, t));
-        scl = nm->mkNode(AND, i1, i2);
+        if (w > 1)
+        {
+          Node div = nm->mkNode(BITVECTOR_LSHR, s, bv::utils::mkConst(w, 1));
+          Node i1 = nm->mkNode(IMPLIES,
+                               nm->mkNode(BITVECTOR_SGE, s, z),
+                               nm->mkNode(BITVECTOR_SGT, s, t));
+          Node i2 = nm->mkNode(IMPLIES,
+                               nm->mkNode(BITVECTOR_SLT, s, z),
+                               nm->mkNode(BITVECTOR_SGT, div, t));
+          scl = nm->mkNode(AND, i1, i2);
+        }
+        else
+        {
+          scl = nm->mkNode(BITVECTOR_SGT, s, t);
+        }
       }
       else
       {
@@ -1310,9 +1338,7 @@ static Node getScBvAndOr(bool pol,
        * (bvslt t (bvor s max))
        * where
        * max is the signed maximum value with getSize(max) = w  */
-      BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-      BitVector bv_max_val = BitVector(1).concat(bv_ones);
-      Node max = bv::utils::mkConst(bv_max_val);
+      Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
       scl = nm->mkNode(BITVECTOR_SLT, t, nm->mkNode(k, s, max));
     }
     else
@@ -1324,8 +1350,7 @@ static Node getScBvAndOr(bool pol,
          * (bvuge s (bvand t min))
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         scl = nm->mkNode(BITVECTOR_UGE, s, nm->mkNode(BITVECTOR_AND, t, min));
       }
       else
@@ -1335,8 +1360,7 @@ static Node getScBvAndOr(bool pol,
          * (bvsge t (bvor s min))
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         scl = nm->mkNode(BITVECTOR_SGE, t, nm->mkNode(BITVECTOR_OR, s, min));
       }
     }
@@ -1587,9 +1611,7 @@ static Node getScBvLshr(bool pol,
          * (bvslt t (bvlshr (bvshl max s) s))
          * where
          * max is the signed maximum value with getSize(max) = w  */
-        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max_val = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max_val);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node shl = nm->mkNode(BITVECTOR_SHL, max, s);
         Node lshr = nm->mkNode(BITVECTOR_LSHR, shl, s);
         scl = nm->mkNode(BITVECTOR_SLT, t, lshr);
@@ -1627,8 +1649,7 @@ static Node getScBvLshr(bool pol,
          * (or (bvult t min) (bvsge t s))
          * where
          * min is the minimum signed value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node ult = nm->mkNode(BITVECTOR_ULT, t, min);
         Node sge = nm->mkNode(BITVECTOR_SGE, t, s);
         scl = ult.orNode(sge);
@@ -1806,8 +1827,7 @@ static Node getScBvAshr(bool pol,
          * (or (bvult s min) (bvuge t s))
          * where
          * min is the minimum signed value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node ult = nm->mkNode(BITVECTOR_ULT, s, min);
         Node uge = nm->mkNode(BITVECTOR_UGE, t, s);
         scl = ult.orNode(uge);
@@ -1825,8 +1845,7 @@ static Node getScBvAshr(bool pol,
          * (bvslt (bvashr min s) t)
          * where
          * min is the minimum signed value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         scl = nm->mkNode(BITVECTOR_SLT, nm->mkNode(BITVECTOR_ASHR, min, s), t);
       }
       else
@@ -1836,9 +1855,7 @@ static Node getScBvAshr(bool pol,
          * (bvsge (bvlshr max s) t)
          * where
          * max is the signed maximum value with getSize(max) = w  */
-        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max_val = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max_val);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         scl = nm->mkNode(BITVECTOR_SGE, nm->mkNode(BITVECTOR_LSHR, max, s), t);
       }
     }
@@ -1869,9 +1886,7 @@ static Node getScBvAshr(bool pol,
   else
   {
     Assert(litk == BITVECTOR_SGT);
-    BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-    BitVector bv_max_val = BitVector(1).concat(bv_ones);
-    Node max = bv::utils::mkConst(bv_max_val);
+    Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
     if (idx == 0)
     {
       Node lshr = nm->mkNode(BITVECTOR_LSHR, max, s);
@@ -2087,8 +2102,7 @@ static Node getScBvShl(bool pol,
          * (bvslt (bvshl (bvlshr min s) s) t)
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node lshr = nm->mkNode(BITVECTOR_LSHR, min, s);
         Node shl = nm->mkNode(BITVECTOR_SHL, lshr, s);
         scl = nm->mkNode(BITVECTOR_SLT, shl, t);
@@ -2100,9 +2114,7 @@ static Node getScBvShl(bool pol,
          * (bvsge (bvand (bvshl max s) max) t)
          * where
          * max is the signed maximum value with getSize(max) = w  */
-        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max_val = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max_val);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node shl = nm->mkNode(BITVECTOR_SHL, max, s);
         scl = nm->mkNode(BITVECTOR_SGE, nm->mkNode(BITVECTOR_AND, shl, max), t);
       }
@@ -2116,8 +2128,7 @@ static Node getScBvShl(bool pol,
          * (bvult (bvshl min s) (bvadd t min))
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node shl = nm->mkNode(BITVECTOR_SHL, min, s);
         Node add = nm->mkNode(BITVECTOR_PLUS, t, min);
         scl = nm->mkNode(BITVECTOR_ULT, shl, add);
@@ -2144,9 +2155,7 @@ static Node getScBvShl(bool pol,
          * (bvslt t (bvand (bvshl max s) max))
          * where
          * max is the signed maximum value with getSize(max) = w  */
-        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
-        BitVector bv_max_val = BitVector(1).concat(bv_ones);
-        Node max = bv::utils::mkConst(bv_max_val);
+        Node max = bv::utils::mkConst(bv::utils::mkBitVectorMaxSigned(w));
         Node shl = nm->mkNode(BITVECTOR_SHL, max, s);
         scl = nm->mkNode(BITVECTOR_SLT, t, nm->mkNode(BITVECTOR_AND, shl, max));
       }
@@ -2157,8 +2166,7 @@ static Node getScBvShl(bool pol,
          * (bvult (bvlshr t (bvlshr t s)) min)
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         Node ts = nm->mkNode(BITVECTOR_LSHR, t, s);
         scl = nm->mkNode(BITVECTOR_ULT, nm->mkNode(BITVECTOR_LSHR, t, ts), min);
       }
@@ -2180,8 +2188,7 @@ static Node getScBvShl(bool pol,
          * (bvult (bvlshr t s) min)
          * where
          * min is the signed minimum value with getSize(min) = w  */
-        BitVector bv_min_val = BitVector(w).setBit(w - 1);
-        Node min = bv::utils::mkConst(bv_min_val);
+        Node min = bv::utils::mkConst(bv::utils::mkBitVectorMinSigned(w));
         scl = nm->mkNode(BITVECTOR_ULT, nm->mkNode(BITVECTOR_LSHR, t, s), min);
       }
     }