Add side conditions for inequalities of ASHR. (#1461)
authorMathias Preiner <mathias.preiner@gmail.com>
Wed, 3 Jan 2018 07:57:43 +0000 (23:57 -0800)
committerGitHub <noreply@github.com>
Wed, 3 Jan 2018 07:57:43 +0000 (23:57 -0800)
src/theory/quantifiers/bv_inverter.cpp

index c794c693c85aa11004419e9400ac91eba41a879a..30dd2a02a3fbf0d96688bc28922ed09580528772 100644 (file)
@@ -881,7 +881,7 @@ static Node getScBvAndOr(bool pol,
           || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
 
   NodeManager* nm = NodeManager::currentNM();
-  Node scl, scr;
+  Node scl;
 
   if (litk == EQUAL)
   {
@@ -1006,8 +1006,7 @@ static Node getScBvAndOr(bool pol,
   {
     return Node::null();
   }
-
-  scr = nm->mkNode(litk, nm->mkNode(k, x, s), t);
+  Node scr = nm->mkNode(litk, nm->mkNode(k, x, s), t);
   Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode());
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
@@ -1222,9 +1221,11 @@ static Node getScBvAshr(bool pol,
                         Node t)
 {
   Assert(k == BITVECTOR_ASHR);
+  Assert(litk == EQUAL || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
+         || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
 
   NodeManager* nm = NodeManager::currentNM();
-  Node scl, scr;
+  Node scl;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
   Node z = bv::utils::mkZero(w);
@@ -1270,13 +1271,12 @@ static Node getScBvAshr(bool pol,
             nm->mkNode(OR, t.eqNode(z), t.eqNode(n)));
 
         scl = nm->mkNode(OR, o1, o2, o3);
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
       }
       else
       {
         /* x >> s != t
-         * no side condition */
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, x, s), t);
+         * true (no side condition) */
+        scl = nm->mkConst<bool>(true);
       }
     }
     else
@@ -1314,32 +1314,229 @@ static Node getScBvAshr(bool pol,
           Node o = nm->mkNode(OR, o1, o2);
 
           Node e = nm->mkNode(EQUAL,
-              bv::utils::mkExtract(t, w-1-i, 0), bv::utils::mkExtract(s, w-1, i));
+                              bv::utils::mkExtract(t, w - 1 - i, 0),
+                              bv::utils::mkExtract(s, w - 1, i));
 
           nb << nm->mkNode(AND, e, o);
         }
         nb << nm->mkNode(AND, msbz, t.eqNode(z));
         nb << nm->mkNode(AND, msbn, t.eqNode(n));
         scl = nb.constructNode();
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
       }
       else
       {
         /* s >> x != t
          * with side condition:
-         * (t != 0 || s != 0) && (t != ~0 || s != ~0)  */
+         * (and
+         *  (or (not (= t z)) (not (= s z)))
+         *  (or (not (= t (bvnot z)) (not (= s (bvnot z))))))
+         */
         scl = nm->mkNode(AND,
             nm->mkNode(OR, t.eqNode(z).notNode(), s.eqNode(z).notNode()),
             nm->mkNode(OR, t.eqNode(n).notNode(), s.eqNode(n).notNode()));
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, s, x), t);
+      }
+    }
+  }
+  else if (litk == BITVECTOR_ULT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x >> s < t
+         * with side condition (synthesized):
+         * (not (= t z)) */
+        scl = t.eqNode(z).notNode();
+      }
+      else
+      {
+        /* x >> s >= t
+         * with side condition (synthesized):
+         * true (no side condition) */
+        scl = nm->mkConst<bool>(true);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x < t
+         * with side condition (synthesized):
+         * (and (not (and (not (bvult s t)) (bvslt s z))) (not (= t z))) */
+        Node st = nm->mkNode(BITVECTOR_UGE, s, t);
+        Node sz = nm->mkNode(BITVECTOR_SLT, s, z);
+        Node tz = t.eqNode(z).notNode();
+        scl = st.andNode(sz).notNode().andNode(tz);
+      }
+      else
+      {
+        /* s >> x < t
+         * with side condition (synthesized):
+         * (not (and (bvult s (bvnot s)) (bvult s t))) */
+        Node ss = nm->mkNode(BITVECTOR_ULT, s, nm->mkNode(BITVECTOR_NOT, s));
+        Node st = nm->mkNode(BITVECTOR_ULT, s, t);
+        scl = ss.andNode(st).notNode();
+      }
+    }
+  }
+  else if (litk == BITVECTOR_UGT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x >> s > t
+         * with side condition (synthesized):
+         * (bvult t (bvnot #x0))
+         */
+        scl = nm->mkNode(BITVECTOR_ULT, t, bv::utils::mkOnes(w));
+      }
+      else
+      {
+        /* x >> s <= t
+         * with side condition (synthesized):
+         * true (no side condition)
+         */
+        scl = nm->mkConst<bool>(true);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x > t
+         * with side condition (synthesized):
+         * (or (bvslt s (bvlshr s (bvnot t))) (bvult t s))
+         */
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, s, nm->mkNode(BITVECTOR_NOT, t));
+        Node ts = nm->mkNode(BITVECTOR_ULT, t, s);
+        Node slt = nm->mkNode(BITVECTOR_SLT, s, lshr);
+        scl = slt.orNode(ts);
+      }
+      else
+      {
+        /* s >> x <= t
+         * with side condition (synthesized):
+         * (or (bvult s min_val) (bvuge t s))
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        Node ult = nm->mkNode(BITVECTOR_ULT, s, min);
+        Node uge = nm->mkNode(BITVECTOR_UGE, t, s);
+        scl = ult.orNode(uge);
+      }
+    }
+  }
+  else if (litk == BITVECTOR_SLT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x >> s < t
+         * with side condition:
+         * (bvslt (bvashr min_val s) t)
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        scl = nm->mkNode(BITVECTOR_SLT, nm->mkNode(BITVECTOR_ASHR, min, s), t);
+      }
+      else
+      {
+        /* x >> s >= t
+         * with side condition:
+         * (bvsge (bvlshr max_val s) t)
+         * where
+         * max_val is the signed maximum value */
+        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
+        BitVector bv_max_val = BitVector(1).concat(bv_ones);
+        Node max = bv::utils::mkConst(bv_max_val);
+        scl = nm->mkNode(BITVECTOR_SGE, nm->mkNode(BITVECTOR_LSHR, max, s), t);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x < t
+         * with side condition (synthesized):
+         * (or (bvslt s t) (bvslt z t)) */
+        Node st = nm->mkNode(BITVECTOR_SLT, s, t);
+        Node zt = nm->mkNode(BITVECTOR_SLT, z, t);
+        scl = st.orNode(zt);
+      }
+      else
+      {
+        /* s >> x >= t
+         * with side condition (synthesized):
+         * (not (and (bvult t (bvnot t)) (bvslt s t))) */
+        Node tt = nm->mkNode(BITVECTOR_ULT, t, nm->mkNode(BITVECTOR_NOT, t));
+        Node st = nm->mkNode(BITVECTOR_SLT, s, t);
+        scl = tt.andNode(st).notNode();
       }
     }
   }
   else
   {
-    return Node::null();
+    Assert(litk == BITVECTOR_SGT);
+    BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
+    BitVector bv_max_val = BitVector(1).concat(bv_ones);
+    Node max = bv::utils::mkConst(bv_max_val);
+    if (idx == 0)
+    {
+      Node lshr = nm->mkNode(BITVECTOR_LSHR, max, s);
+      if (pol)
+      {
+        /* x >> s > t
+         * with side condition (synthesized):
+         * (bvslt t (bvlshr max_val s)))
+         * where
+         * max_val is the signed maximum value */
+        scl = nm->mkNode(BITVECTOR_SLT, t, lshr);
+      }
+      else
+      {
+        /* x >> s <= t
+         * with side condition (synthesized):
+         * (bvsge t (bvnot (bvlshr max_value s)))
+         * where
+         * max_val is the signed maximum value */
+        scl = nm->mkNode(BITVECTOR_SGE, t, nm->mkNode(BITVECTOR_NOT, lshr));
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x > t
+         * with side condition (synthesized):
+         * (and (bvslt t (bvand s max_val)) (bvslt t (bvor s max_val)))
+         * where
+         * max_val is the signed maximum value */
+        Node sam = nm->mkNode(BITVECTOR_AND, s, max);
+        Node som = nm->mkNode(BITVECTOR_OR, s, max);
+        Node slta = nm->mkNode(BITVECTOR_SLT, t, sam);
+        Node slto = nm->mkNode(BITVECTOR_SLT, t, som);
+        scl = slta.andNode(slto);
+      }
+      else
+      {
+        /* s >> x <= t
+         * with side condition (synthesized):
+         * (not (and (bvslt t z) (bvslt t s)))
+         * (or (bvsge t z) (bvsge t s))
+         */
+        Node tz = nm->mkNode(BITVECTOR_SGE, t, z);
+        Node ts = nm->mkNode(BITVECTOR_SGE, t, s);
+        scl = tz.orNode(ts);
+      }
+    }
   }
-  Node sc = scl.isNull() ? scr : nm->mkNode(IMPLIES, scl, scr);
+  Node scr =
+      nm->mkNode(litk, idx == 0 ? nm->mkNode(k, x, s) : nm->mkNode(k, s, x), t);
+  Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode());
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
 }