Fix side condition handling for PLUS, XOR, SIGN_EXTEND for CBQI BV. (#1480)
authorAina Niemetz <aina.niemetz@gmail.com>
Fri, 5 Jan 2018 02:32:48 +0000 (18:32 -0800)
committerMathias Preiner <mathias.preiner@gmail.com>
Fri, 5 Jan 2018 02:32:48 +0000 (18:32 -0800)
src/theory/quantifiers/bv_inverter.cpp

index f33229ae0106e33fd864fd0d57fabf72f7d3d307..2bc93de60ad369171bc887d152a46bfdc91254fa 100644 (file)
@@ -219,7 +219,8 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t)
     {
       /* x < t
        * with side condition:
-       * t != 0  */
+       * (distinct t z)
+       * where z = 0 with getSize(z) = w  */
       Node scl = nm->mkNode(DISTINCT, t, bv::utils::mkZero(w));
       Node scr = nm->mkNode(k, x, t);
       sc = nm->mkNode(IMPLIES, scl, scr);
@@ -238,7 +239,8 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t)
     {
       /* x > t
        * with side condition:
-       * t != ~0  */
+       * (distinct t ones)
+       * where ones = ~0 with getSize(ones) = w  */
       Node scl = nm->mkNode(DISTINCT, t, bv::utils::mkOnes(w));
       Node scr = nm->mkNode(k, x, t);
       sc = nm->mkNode(IMPLIES, scl, scr);
@@ -268,7 +270,8 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
     {
       /* x < t
        * with side condition:
-       * t != 10...0 */
+       * (distinct t min)
+       * where min is the minimum signed value with getSize(min) = w  */
       Node min = bv::utils::mkConst(BitVector(w).setBit(w - 1));
       Node scl = nm->mkNode(DISTINCT, min, t);
       Node scr = nm->mkNode(k, x, t);
@@ -288,7 +291,8 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t)
     {
       /* x > t
        * with side condition:
-       * t != 01...1  */
+       * (distinct t max)
+       * where max is the maximum signed value with getSize(max) = w  */
       BitVector bv = BitVector(w).setBit(w - 1);
       Node max = bv::utils::mkConst(~bv);
       Node scl = nm->mkNode(DISTINCT, t, max);
@@ -316,8 +320,8 @@ static Node getScBvMult(bool pol,
 {
   Assert(k == BITVECTOR_MULT);
   Assert (litk == EQUAL
-      || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
-      || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
+       || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
+       || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
 
   NodeManager* nm = NodeManager::currentNM();
   Node scl;
@@ -341,7 +345,8 @@ static Node getScBvMult(bool pol,
        *   (= t z)
        *   (and
        *     (bvuge (bvand t (bvneg t)) (bvand s (bvneg s)))
-       *     (distinct s z)))  */
+       *     (distinct s z)))
+       * where z = 0 with getSize(z) = w  */
       Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s);
       scl = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_AND, o, t), t);
     }
@@ -406,7 +411,8 @@ static Node getScBvMult(bool pol,
     {
       /* x * s >= t
        * with side condition (synthesized):
-       * (bvsge (bvand (bvor (bvneg s) s) max) t))  */
+       * (bvsge (bvand (bvor (bvneg s) s) max) t))
+       * where max is the maximum signed value with getSize(max) = w  */
       BitVector bv = BitVector(w).setBit(w - 1);
       Node max = bv::utils::mkConst(~bv);
       Node o = nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NEG, s), s);
@@ -455,8 +461,8 @@ static Node getScBvUrem(bool pol,
 {
   Assert(k == BITVECTOR_UREM_TOTAL);
   Assert (litk == EQUAL
-      || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
-      || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
+       || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
+       || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
 
   NodeManager* nm = NodeManager::currentNM();
   Node scl;
@@ -677,7 +683,7 @@ static Node getScBvUrem(bool pol,
          *     (=> (bvsle s z) (distinct t max)))
          *   (or (distinct t z) (distinct s (_ bv1 w))))
          * where z = 0 with getSize(z) = w
-         * and max is the maximum signed value  */
+         * and max is the maximum signed value with getSize(max) = w  */
         BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
         BitVector bv_max = BitVector(1).concat(bv_ones);
         Node max = bv::utils::mkConst(bv_max);
@@ -726,7 +732,7 @@ static Node getScBvUrem(bool pol,
         /* s % x < t
          * with side condition (synthesized):
          * (or (bvult t min) (bvsge t s))
-         * where min is the minimum signed value  */
+         * where min is the minimum signed value with getSize(min) = w  */
         BitVector bv_min = BitVector(w).setBit(w - 1);
         Node min = bv::utils::mkConst(bv_min);
         Node o1 = nm->mkNode(BITVECTOR_ULT, t, min);
@@ -791,7 +797,8 @@ static Node getScBvUdiv(bool pol,
         /* x udiv s != t
          * with side condition:
          * (or (distinct s z) (distinct t ones))
-         * where z = 0 with getSize(z) = w and ones = ~0   */
+         * where z = 0 with getSize(z) = w
+         * and ones = ~0 with getSize(ones) = w  */
         Node ones = bv::utils::mkOnes(w);
         scl = nm->mkNode(OR, s.eqNode(z).notNode(), t.eqNode(ones).notNode());
       }
@@ -882,7 +889,7 @@ static Node getScBvUdiv(bool pol,
         /* x udiv s > t
          * with side condition:
          * (bvugt (bvudiv ones s) t)
-         * with ones = ~0   */
+         * where ones = ~0 with getSize(ones) = w  */
         Node ones = bv::utils::mkOnes(w);
         Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s);
         scl = nm->mkNode(BITVECTOR_UGT, div, t);
@@ -904,7 +911,7 @@ static Node getScBvUdiv(bool pol,
         /* s udiv x > t
          * with side condition (synthesized):
          * (bvult t ones)
-         * with ones = ~0   */
+         * where ones = ~0 with getSize(ones) = w  */
         Node ones = bv::utils::mkOnes(w);
         scl = nm->mkNode(BITVECTOR_ULT, t, ones);
       }
@@ -929,7 +936,7 @@ static Node getScBvUdiv(bool pol,
          * with side condition:
          * (=> (bvsle t z) (bvslt (bvudiv min s) t))
          * where z = 0 with getSize(z) = w
-         * and min is the minimum signed value */
+         * and min is the minimum signed value with getSize(min) = w  */
         BitVector bv_min = BitVector(w).setBit(w - 1);
         Node min = bv::utils::mkConst(bv_min);
         Node sle = nm->mkNode(BITVECTOR_SLE, t, z);
@@ -944,7 +951,8 @@ static Node getScBvUdiv(bool pol,
          * (or
          *   (bvsge (bvudiv ones s) t)
          *   (bvsge (bvudiv max s) t))
-         * with ones = ~0 and max the maximum signed value */
+         * where ones = ~0 with getSize(ones) = w
+         * and max is the maximum signed value with getSize(max) = w  */
         BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
         BitVector bv_max = BitVector(1).concat(bv_ones);
         Node max = bv::utils::mkConst(bv_max);
@@ -997,7 +1005,8 @@ static Node getScBvUdiv(bool pol,
          * (or
          *   (bvsgt (bvudiv ones s) t)
          *   (bvsgt (bvudiv max s) t))
-         * with ones = ~0 and max the maximum signed value */
+         * where ones = ~0 with getSize(ones) = w
+         * and max is the maximum signed value with getSize(max) = w  */
         BitVector bv_ones = utils::mkBitVectorOnes(w - 1);
         BitVector bv_max = BitVector(1).concat(bv_ones);
         Node max = bv::utils::mkConst(bv_max);
@@ -1015,7 +1024,8 @@ static Node getScBvUdiv(bool pol,
          * (or
          *   (= (bvudiv (bvmul s t) s) t)                ; eq, synthesized
          *   (=> (bvsle t z) (bvslt (bvudiv min s) t)))  ; slt
-         * where z = 0 with getSize(z) = w  */
+         * where z = 0 with getSize(z) = w
+         * and min is the minimum signed value with getSize(min) = w  */
         Node mul = nm->mkNode(BITVECTOR_MULT, s, t);
         Node div1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, s);
         Node o1 = nm->mkNode(EQUAL, div1, t);
@@ -1053,7 +1063,7 @@ static Node getScBvUdiv(bool pol,
          * (not (and (bvslt t (bvnot #x0)) (bvslt t s)))
          * <->
          * (or (bvsge t ones) (bvsge t s))
-         * with ones = ~0  */
+         * where ones = ~0 with getSize(ones) = w  */
         Node ones = bv::utils::mkOnes(w);
         Node sge1 = nm->mkNode(BITVECTOR_SGE, t, ones);
         Node sge2 = nm->mkNode(BITVECTOR_SGE, t, s);
@@ -1406,7 +1416,8 @@ static Node getScBvLshr(bool pol,
       {
         /* x >> s >= t
          * with side condition:
-         * (=> (not (= s z)) (bvsge (bvlshr ones s) t)) */
+         * (=> (not (= s z)) (bvsge (bvlshr ones s) t))
+         * where ones = ~0 with getSize(ones) = w  */
         Node ones = bv::utils::mkOnes(w);
         Node lshr = nm->mkNode(BITVECTOR_LSHR, ones, s);
         Node nz = s.eqNode(z).notNode();
@@ -1450,9 +1461,9 @@ static Node getScBvLshr(bool pol,
       {
         /* x >> s > t
          * with side condition (synthesized):
-         * (bvslt t (bvlshr (bvshl max_val s) s))
+         * (bvslt t (bvlshr (bvshl max s) s))
          * where
-         * max_val is the signed maximum value */
+         * max is the signed maximum value with getSize(max) = w  */
         BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
         BitVector bv_max_val = BitVector(1).concat(bv_ones);
         Node max = bv::utils::mkConst(bv_max_val);
@@ -1488,9 +1499,9 @@ static Node getScBvLshr(bool pol,
       {
         /* s >> x <= t
          * with side condition (synthesized):
-         * (or (bvult t min_val) (bvsge t s))
+         * (or (bvult t min) (bvsge t s))
          * where
-         * min_val is the signed minimum value */
+         * min is the minimum signed value with getSize(min) = w  */
         BitVector bv_min_val = BitVector(w).setBit(w - 1);
         Node min = bv::utils::mkConst(bv_min_val);
         Node ult = nm->mkNode(BITVECTOR_ULT, t, min);
@@ -1711,9 +1722,9 @@ static Node getScBvAshr(bool pol,
       {
         /* s >> x <= t
          * with side condition (synthesized):
-         * (or (bvult s min_val) (bvuge t s))
+         * (or (bvult s min) (bvuge t s))
          * where
-         * min_val is the signed minimum value */
+         * min is the minimum signed value with getSize(min) = w  */
         BitVector bv_min_val = BitVector(w).setBit(w - 1);
         Node min = bv::utils::mkConst(bv_min_val);
         Node ult = nm->mkNode(BITVECTOR_ULT, s, min);
@@ -1730,9 +1741,9 @@ static Node getScBvAshr(bool pol,
       {
         /* x >> s < t
          * with side condition:
-         * (bvslt (bvashr min_val s) t)
+         * (bvslt (bvashr min s) t)
          * where
-         * min_val is the signed minimum value */
+         * min is the minimum signed value with getSize(min) = w  */
         BitVector bv_min_val = BitVector(w).setBit(w - 1);
         Node min = bv::utils::mkConst(bv_min_val);
         scl = nm->mkNode(BITVECTOR_SLT, nm->mkNode(BITVECTOR_ASHR, min, s), t);
@@ -1741,9 +1752,9 @@ static Node getScBvAshr(bool pol,
       {
         /* x >> s >= t
          * with side condition:
-         * (bvsge (bvlshr max_val s) t)
+         * (bvsge (bvlshr max s) t)
          * where
-         * max_val is the signed maximum value */
+         * max is the signed maximum value with getSize(max) = w */
         BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
         BitVector bv_max_val = BitVector(1).concat(bv_ones);
         Node max = bv::utils::mkConst(bv_max_val);
@@ -1785,18 +1796,18 @@ static Node getScBvAshr(bool pol,
       {
         /* x >> s > t
          * with side condition (synthesized):
-         * (bvslt t (bvlshr max_val s)))
+         * (bvslt t (bvlshr max s)))
          * where
-         * max_val is the signed maximum value */
+         * max is the signed maximum value with getSize(max) = w  */
         scl = nm->mkNode(BITVECTOR_SLT, t, lshr);
       }
       else
       {
         /* x >> s <= t
          * with side condition (synthesized):
-         * (bvsge t (bvnot (bvlshr max_value s)))
+         * (bvsge t (bvnot (bvlshr max s)))
          * where
-         * max_val is the signed maximum value */
+         * max is the signed maximum value with getSize(max) = w */
         scl = nm->mkNode(BITVECTOR_SGE, t, nm->mkNode(BITVECTOR_NOT, lshr));
       }
     }
@@ -1806,9 +1817,9 @@ static Node getScBvAshr(bool pol,
       {
         /* s >> x > t
          * with side condition (synthesized):
-         * (and (bvslt t (bvand s max_val)) (bvslt t (bvor s max_val)))
+         * (and (bvslt t (bvand s max)) (bvslt t (bvor s max)))
          * where
-         * max_val is the signed maximum value */
+         * max is the signed maximum value with getSize(max) = w  */
         Node sam = nm->mkNode(BITVECTOR_AND, s, max);
         Node som = nm->mkNode(BITVECTOR_OR, s, max);
         Node slta = nm->mkNode(BITVECTOR_SLT, t, sam);
@@ -2217,7 +2228,7 @@ Node BvInverter::solveBvLit(Node sv,
     {
       t = nm->mkNode(k, t);
     }
-    else if (k == BITVECTOR_CONCAT && litk == EQUAL)
+    else if (k == BITVECTOR_CONCAT)
     {
       /* x = t[upper:lower]
        * where
@@ -2235,7 +2246,7 @@ Node BvInverter::solveBvLit(Node sv,
       }
       t = bv::utils::mkExtract(t, upper, lower);
     }
-    else if (k == BITVECTOR_SIGN_EXTEND && litk == EQUAL)
+    else if (k == BITVECTOR_SIGN_EXTEND)
     {
       t = bv::utils::mkExtract(t, bv::utils::getSize(sv_t[index]) - 1, 0);
     }
@@ -2251,11 +2262,11 @@ Node BvInverter::solveBvLit(Node sv,
       Node s = nchildren == 2 ? sv_t[1 - index] : dropChild(sv_t, index);
       /* Note: All n-ary kinds except for CONCAT (i.e., AND, OR, MULT, PLUS)
        *       are commutative (no case split based on index). */
-      if (k == BITVECTOR_PLUS && litk == EQUAL)
+      if (k == BITVECTOR_PLUS)
       {
         t = nm->mkNode(BITVECTOR_SUB, t, s);
       }
-      else if (k == BITVECTOR_XOR && litk == EQUAL)
+      else if (k == BITVECTOR_XOR)
       {
         t = nm->mkNode(BITVECTOR_XOR, t, s);
       }