Add side conditions for inequalities of LSHR. (#1462)
authorMathias Preiner <mathias.preiner@gmail.com>
Tue, 2 Jan 2018 19:21:12 +0000 (11:21 -0800)
committerGitHub <noreply@github.com>
Tue, 2 Jan 2018 19:21:12 +0000 (11:21 -0800)
src/theory/quantifiers/bv_inverter.cpp

index cce56ab011fc5e71cf36009da49b6c258e276c68..c1e59e3c02ed1abab71dbffab8dafa774a3074e1 100644 (file)
@@ -788,9 +788,10 @@ static Node getScBvLshr(bool pol,
                         Node t)
 {
   Assert(k == BITVECTOR_LSHR);
+  Assert(litk == EQUAL || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT);
 
   NodeManager* nm = NodeManager::currentNM();
-  Node scl, scr;
+  Node scl;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
   Node z = bv::utils::mkZero(w);
@@ -824,7 +825,6 @@ static Node getScBvLshr(bool pol,
             nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
 
         scl = nm->mkNode(OR, o1, o2, o3);
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
       }
       else
       {
@@ -837,7 +837,6 @@ static Node getScBvLshr(bool pol,
         scl = nm->mkNode(OR,
             t.eqNode(z).notNode(),
             nm->mkNode(BITVECTOR_ULT, s, ww));
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, x, s), t);
       }
     }
     else
@@ -868,7 +867,6 @@ static Node getScBvLshr(bool pol,
         }
         nb << t.eqNode(z);
         scl = nb.constructNode();
-        scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
       }
       else
       {
@@ -876,15 +874,102 @@ static Node getScBvLshr(bool pol,
          * with side condition:
          * s != 0 || t != 0  */
         scl = nm->mkNode(OR, s.eqNode(z).notNode(), t.eqNode(z).notNode());
-        scr = nm->mkNode(DISTINCT, nm->mkNode(k, s, x), t);
       }
     }
   }
-  else
+  else if (litk == BITVECTOR_ULT)
   {
-    return Node::null();
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x >> s < t
+         * with side condition (synthesized):
+         * (not (= #x0 t)) */
+        scl = t.eqNode(z).notNode();
+      }
+      else
+      {
+        /* x >> s >= t
+         * with side condition (synthesized):
+         * (= (bvlshr (bvshl t s) s) t) */
+        Node ts = nm->mkNode(BITVECTOR_SHL, t, s);
+        scl = nm->mkNode(BITVECTOR_LSHR, ts, s).eqNode(t);
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x < t
+         * with side condition (synthesized):
+         * (not (= #x0 t)) */
+        scl = t.eqNode(z).notNode();
+      }
+      else
+      {
+        /* s >> x >= t
+         * with side condition (synthesized):
+         * (bvuge s t) */
+        scl = nm->mkNode(BITVECTOR_UGE, s, t);
+      }
+    }
   }
-  Node sc = nm->mkNode(IMPLIES, scl, scr);
+  else /* litk == BITVECTOR_SLT */
+  {
+    if (idx == 0)
+    {
+      if (pol)
+      {
+        /* x >> s < t
+         * with side condition (synthesized):
+         * (bvslt (bvlshr (bvnot (bvneg t)) s) t) */
+        Node nnt = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, t));
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, nnt, s);
+        scl = nm->mkNode(BITVECTOR_SLT, lshr, t);
+      }
+      else
+      {
+        /* x >> s >= t
+         * with side condition:
+         * (=> (not (= s z)) (bvsge (bvlshr ones s) t)) */
+        Node ones = bv::utils::mkOnes(w);
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, ones, s);
+        Node nz = s.eqNode(z).notNode();
+        scl = nz.impNode(nm->mkNode(BITVECTOR_SGE, lshr, t));
+      }
+    }
+    else
+    {
+      if (pol)
+      {
+        /* s >> x < t
+         * with side condition (synthesized):
+         * (or (bvslt s t) (bvslt z t)) */
+        Node st = nm->mkNode(BITVECTOR_SLT, s, t);
+        Node zt = nm->mkNode(BITVECTOR_SLT, z, t);
+        scl = st.orNode(zt);
+      }
+      else
+      {
+        /* s >> x >= t
+         * with side condition:
+         * (and
+         *  (=> (bvslt s z) (bvsge (bvlshr s one) t))
+         *  (=> (bvsge s z) (bvsge s t))
+         * ) */
+        Node one = bv::utils::mkConst(w, 1);
+        Node sz = nm->mkNode(BITVECTOR_SLT, s, z);
+        Node lshr = nm->mkNode(BITVECTOR_LSHR, s, one);
+        Node sge1 = nm->mkNode(BITVECTOR_SGE, lshr, t);
+        Node sge2 = nm->mkNode(BITVECTOR_SGE, s, t);
+        scl = sz.impNode(sge1).andNode(sz.notNode().impNode(sge2));
+      }
+    }
+  }
+  Node scr =
+    nm->mkNode(litk, idx == 0 ? nm->mkNode(k, x, s) : nm->mkNode(k, s, x), t);
+  Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode());
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
 }