Add side conditions for inequalities of SHL. (#1472)
authorMathias Preiner <mathias.preiner@gmail.com>
Fri, 5 Jan 2018 01:28:42 +0000 (17:28 -0800)
committerGitHub <noreply@github.com>
Fri, 5 Jan 2018 01:28:42 +0000 (17:28 -0800)
src/theory/quantifiers/bv_inverter.cpp

index d5bc49eff789e2466dd65cc36c0a2cb2dcf06a96..f33229ae0106e33fd864fd0d57fabf72f7d3d307 100644 (file)
@@ -1320,7 +1320,7 @@ static Node getScBvLshr(bool pol,
       {
         /* x >> s < t
          * with side condition (synthesized):
-         * (not (= #x0 t)) */
+         * (not (= z t)) */
         scl = t.eqNode(z).notNode();
       }
       else
@@ -1338,7 +1338,7 @@ static Node getScBvLshr(bool pol,
       {
         /* s >> x < t
          * with side condition (synthesized):
-         * (not (= #x0 t)) */
+         * (not (= z t)) */
         scl = t.eqNode(z).notNode();
       }
       else
@@ -1835,6 +1835,28 @@ static Node getScBvAshr(bool pol,
   return sc;
 }
 
+static Node naiveShlSc1(Kind litk, Node s, Node t)
+{
+  unsigned w;
+  NodeBuilder<> nb(OR);
+  NodeManager *nm;
+
+  nm = NodeManager::currentNM();
+
+  w = bv::utils::getSize(s);
+  Assert(w == bv::utils::getSize(t));
+
+  nb << nm->mkNode(litk, s, t);
+  for (unsigned i = 1; i <= w; i++)
+  {
+    Node sw = bv::utils::mkConst(w, i);
+    nb << nm->mkNode(litk, nm->mkNode(BITVECTOR_SHL, s, sw), t);
+  }
+  if (nb.getNumChildren() == 1)
+    return nb[0];
+  return nb.constructNode();
+}
+
 static Node getScBvShl(bool pol,
                        Kind litk,
                        Kind k,
@@ -1844,9 +1866,11 @@ static Node getScBvShl(bool pol,
                        Node t)
 {
   Assert(k == BITVECTOR_SHL);
+  Assert(litk == EQUAL || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT
+         || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT);
 
   NodeManager* nm = NodeManager::currentNM();
-  Node scl, scr;
+  Node scl;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
   Node z = bv::utils::mkZero(w);
@@ -1884,7 +1908,6 @@ static Node getScBvShl(bool pol,
             nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
 
         scl = nm->mkNode(OR, o1, o2, o3);
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
       }
       else
       {
@@ -1897,7 +1920,6 @@ static Node getScBvShl(bool pol,
         scl = nm->mkNode(OR,
             t.eqNode(z).notNode(),
             nm->mkNode(BITVECTOR_ULT, s, ww));
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, x, s), t);
       }
     }
     else
@@ -1926,7 +1948,6 @@ static Node getScBvShl(bool pol,
         }
         nb << t.eqNode(z);
         scl = nb.constructNode();
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
       }
       else
       {
@@ -1934,15 +1955,200 @@ static Node getScBvShl(bool pol,
          * with side condition:
          * s != 0 || t != 0  */
         scl = nm->mkNode(OR, s.eqNode(z).notNode(), t.eqNode(z).notNode());
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, s, x), t);
+      }
+    }
+  }
+  else if (litk == BITVECTOR_ULT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x << s < t
+         * with side condition (synthesized):
+         * (not (= t z)) */
+        scl = t.eqNode(z).notNode();
+      }
+      else
+      {
+        /* x << s >= t
+         * with side condition (synthesized):
+         * (bvuge (bvshl ones s) t) */
+        Node shl = nm->mkNode(BITVECTOR_SHL, bv::utils::mkOnes(w), s);
+        scl = nm->mkNode(BITVECTOR_UGE, shl, t);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s << x < t
+         * with side condition (synthesized):
+         * (not (= t z)) */
+        scl = t.eqNode(z).notNode();
+      }
+      else
+      {
+        /* s << x >= t
+         * with side condition:
+         * (or (bvuge (bvshl s i) t) ...)
+         * for i in 0..w-1 */
+        scl = naiveShlSc1(BITVECTOR_UGE, s, t);
+      }
+    }
+  }
+  else if (litk == BITVECTOR_UGT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x << s > t
+         * with side condition (synthesized):
+         * (bvult t (bvshl ones s)) */
+        Node shl = nm->mkNode(BITVECTOR_SHL, bv::utils::mkOnes(w), s);
+        scl = nm->mkNode(BITVECTOR_ULT, t, shl);
+      }
+      else
+      {
+        /* x << s <= t
+         * with side condition:
+         * true (no side condition) */
+        scl = nm->mkConst<bool>(true);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s << x > t
+         * with side condition:
+         * (or (bvugt (bvshl s i) t) ...)
+         * for i in 0..w-1 */
+        scl = naiveShlSc1(BITVECTOR_UGT, s, t);
+      }
+      else
+      {
+        /* s << x <= t
+         * with side condition:
+         * true (no side condition) */
+        scl = nm->mkConst<bool>(true);
+      }
+    }
+  }
+  else if (litk == BITVECTOR_SLT)
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x << s < t
+         * with side condition (synthesized):
+         * (bvslt (bvshl (bvlshr min_val s) s) t)
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, min, s);
+        Node shl = nm->mkNode(BITVECTOR_SHL, lshr, s);
+        scl = nm->mkNode(BITVECTOR_SLT, shl, t);
+      }
+      else
+      {
+        /* x << s >= t
+         * with side condition (synthesized):
+         * (bvsge (bvand (bvshl max_val s) max_val) t)
+         * where
+         * max_val is the signed maximum value */
+        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
+        BitVector bv_max_val = BitVector(1).concat(bv_ones);
+        Node max = bv::utils::mkConst(bv_max_val);
+        Node shl = nm->mkNode(BITVECTOR_SHL, max, s);
+        scl = nm->mkNode(BITVECTOR_SGE, nm->mkNode(BITVECTOR_AND, shl, max), t);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s << x < t
+         * with side condition (synthesized):
+         * (bvult (bvshl min_val s) (bvadd t min_val))
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        Node shl = nm->mkNode(BITVECTOR_SHL, min, s);
+        Node add = nm->mkNode(BITVECTOR_PLUS, t, min);
+        scl = nm->mkNode(BITVECTOR_ULT, shl, add);
+      }
+      else
+      {
+        /* s << x >= t
+         * with side condition:
+         * (or (bvsge (bvshl s i) t) ...)
+         * for i in 0..w-1 */
+        scl = naiveShlSc1(BITVECTOR_SGE, s, t);
       }
     }
   }
   else
   {
-    return Node::null();
+    Assert(litk == BITVECTOR_SGT);
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x << s > t
+         * with side condition (synthesized):
+         * (bvslt t (bvand (bvshl max_val s) max_val))
+         * where
+         * max_val is the signed maximum value */
+        BitVector bv_ones = bv::utils::mkBitVectorOnes(w - 1);
+        BitVector bv_max_val = BitVector(1).concat(bv_ones);
+        Node max = bv::utils::mkConst(bv_max_val);
+        Node shl = nm->mkNode(BITVECTOR_SHL, max, s);
+        scl = nm->mkNode(BITVECTOR_SLT, t, nm->mkNode(BITVECTOR_AND, shl, max));
+      }
+      else
+      {
+        /* x << s <= t
+         * with side condition (synthesized):
+         * (bvult (bvlshr t (bvlshr t s)) min_val)
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        Node ts = nm->mkNode(BITVECTOR_LSHR, t, s);
+        scl = nm->mkNode(BITVECTOR_ULT, nm->mkNode(BITVECTOR_LSHR, t, ts), min);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s << x > t
+         * with side condition:
+         * (or (bvsgt (bvshl s i) t) ...)
+         * for i in 0..w-1 */
+        scl = naiveShlSc1(BITVECTOR_SGT, s, t);
+      }
+      else
+      {
+        /* s << x <= t
+         * with side condition (synthesized):
+         * (bvult (bvlshr t s) min_val)
+         * where
+         * min_val is the signed minimum value */
+        BitVector bv_min_val = BitVector(w).setBit(w - 1);
+        Node min = bv::utils::mkConst(bv_min_val);
+        scl = nm->mkNode(BITVECTOR_ULT, nm->mkNode(BITVECTOR_LSHR, t, s), min);
+      }
+    }
   }
-  Node sc = nm->mkNode(IMPLIES, scl, scr);
+  Node scr =
+    nm->mkNode(litk, idx == 0 ? nm->mkNode(k, x, s) : nm->mkNode(k, s, x), t);
+  Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode());
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
 }