Add missing side conditions for SHL, LSHR, ASHR for CBQI BV. (#1441)
authorAina Niemetz <aina.niemetz@gmail.com>
Thu, 14 Dec 2017 03:10:16 +0000 (19:10 -0800)
committerGitHub <noreply@github.com>
Thu, 14 Dec 2017 03:10:16 +0000 (19:10 -0800)
This adds side conditions for operators BITVECTOR_SHL, BITVECTOR_LSHR and
BITVECTOR_ASHR for index = 1, i.e., s << x = t and s >> x = t. Previously, we treated
 these cases as non-invertible.

src/theory/quantifiers/bv_inverter.cpp
test/unit/theory/theory_quantifiers_bv_inverter_white.h

index bad26d14f818851845a6b987d414125d46799d8a..d3fa0715da51320de5cfb048d5d9b61cf3496cf7 100644 (file)
@@ -115,9 +115,9 @@ static bool isInvertible(Kind k, unsigned index)
       ||  k == BITVECTOR_AND
       ||  k == BITVECTOR_OR
       ||  k == BITVECTOR_XOR
-      || (k == BITVECTOR_LSHR && index == 0)
-      || (k == BITVECTOR_ASHR && index == 0)
-      || (k == BITVECTOR_SHL && index == 0);
+      || k == BITVECTOR_LSHR
+      || k == BITVECTOR_ASHR
+      || k == BITVECTOR_SHL;
 }
 
 Node BvInverter::getPathToPv(
@@ -463,34 +463,63 @@ static Node getScBvAndOr(Kind k, unsigned idx, Node x, Node s, Node t)
 static Node getScBvLshr(Kind k, unsigned idx, Node x, Node s, Node t)
 {
   Assert(k == BITVECTOR_LSHR);
-  Assert(idx == 0);
 
   NodeManager* nm = NodeManager::currentNM();
+  Node scl, scr;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
+  Node z = bv::utils::mkZero(w);
   
-  /* x >> s = t
-   * with side condition:
-   * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0)
-   * ->
-   * s = 0 || (s < w && ((z o t) << (z o s))[2w-1 : w] = z) || (s >= w && t = 0)
-   * with w = getSize(t) = getSize(s)
-   * and z = 0 with getSize(z) = w  */
+  if (idx == 0)
+  {
+    /* x >> s = t
+     * with side condition:
+     * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0)
+     * ->
+     * s = 0 || (s < w && ((z o t) << (z o s))[2w-1 : w] = z) || (s >= w && t = 0)
+     * with w = getSize(t) = getSize(s)
+     * and z = 0 with getSize(z) = w  */
 
-  Node z = bv::utils::mkZero(w);
-  Node ww = bv::utils::mkConst(w, w);
+    Node ww = bv::utils::mkConst(w, w);
 
-  Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t);
-  Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s);
-  Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
-  Node ext = bv::utils::mkExtract(shl, 2*w-1, w);
+    Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t);
+    Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s);
+    Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
+    Node ext = bv::utils::mkExtract(shl, 2*w-1, w);
 
-  Node o1 = s.eqNode(z);
-  Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
-  Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
+    Node o1 = s.eqNode(z);
+    Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
+    Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
 
-  Node scl = nm->mkNode(OR, o1, o2, o3);
-  Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+    scl = nm->mkNode(OR, o1, o2, o3);
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+  }
+  else
+  {
+    /* s >> x = t
+     * with side condition:
+     * t = 0
+     * ||
+     * s = t
+     * || 
+     * \/ (t[w-1-i:0] = s[w-1:i] && t[w-1:w-i] = 0) for 0 < i < w
+     * where
+     * w = getSize(s) = getSize(t)
+     */
+    NodeBuilder<> nb(nm, OR);
+    nb << nm->mkNode(EQUAL, t, s);
+    for (unsigned i = 1; i < w; ++i)
+    {
+      nb << nm->mkNode(AND,
+          nm->mkNode(EQUAL,
+            bv::utils::mkExtract(t, w-1-i, 0), bv::utils::mkExtract(s, w-1, i)),
+          nm->mkNode(EQUAL,
+            bv::utils::mkExtract(t, w-1, w-i), bv::utils::mkZero(i)));
+    }
+    nb << t.eqNode(z);
+    scl = nb.constructNode();
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
+  }
   Node sc = nm->mkNode(IMPLIES, scl, scr);
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
@@ -499,49 +528,93 @@ static Node getScBvLshr(Kind k, unsigned idx, Node x, Node s, Node t)
 static Node getScBvAshr(Kind k, unsigned idx, Node x, Node s, Node t)
 {
   Assert(k == BITVECTOR_ASHR);
-  Assert(idx == 0);
 
   NodeManager* nm = NodeManager::currentNM();
+  Node scl, scr;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
-  
-  /* x >> s = t
-   * with side condition:
-   * s = 0
-   * ||
-   * (s < w && (((z o t) << (z o s))[2w-1:w-1] = z
-   *            ||
-   *            ((~z o t) << (z o s))[2w-1:w-1] = ~z))
-   * ||
-   * (s >= w && (t = 0 || t = ~0))
-   * with w = getSize(t) = getSize(s)
-   * and z = 0 with getSize(z) = w  */
-  
   Node z = bv::utils::mkZero(w);
-  Node zz = bv::utils::mkZero(w+1);
   Node n = bv::utils::mkOnes(w);
-  Node nn = bv::utils::mkOnes(w+1);
-  Node ww = bv::utils::mkConst(w, w);
-
-  Node z_o_t = bv::utils::mkConcat(z, t);
-  Node z_o_s = bv::utils::mkConcat(z, s);
-  Node n_o_t = bv::utils::mkConcat(n, t);
-
-  Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
-  Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s);
-  Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1);
-  Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1);
-
-  Node o1 = s.eqNode(z);
-  Node o2 = nm->mkNode(AND,
-      nm->mkNode(BITVECTOR_ULT, s, ww),
-      nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn)));
-  Node o3 = nm->mkNode(AND,
-      nm->mkNode(BITVECTOR_UGE, s, ww),
-      nm->mkNode(OR, t.eqNode(z), t.eqNode(n)));
-
-  Node scl = nm->mkNode(OR, o1, o2, o3);
-  Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+  
+  if (idx == 0)
+  {
+    /* x >> s = t
+     * with side condition:
+     * s = 0
+     * ||
+     * (s < w && (((z o t) << (z o s))[2w-1:w-1] = z
+     *            ||
+     *            ((~z o t) << (z o s))[2w-1:w-1] = ~z))
+     * ||
+     * (s >= w && (t = 0 || t = ~0))
+     * with w = getSize(t) = getSize(s)
+     * and z = 0 with getSize(z) = w  */
+    
+    Node zz = bv::utils::mkZero(w+1);
+    Node nn = bv::utils::mkOnes(w+1);
+    Node ww = bv::utils::mkConst(w, w);
+
+    Node z_o_t = bv::utils::mkConcat(z, t);
+    Node z_o_s = bv::utils::mkConcat(z, s);
+    Node n_o_t = bv::utils::mkConcat(n, t);
+
+    Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
+    Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s);
+    Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1);
+    Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1);
+
+    Node o1 = s.eqNode(z);
+    Node o2 = nm->mkNode(AND,
+        nm->mkNode(BITVECTOR_ULT, s, ww),
+        nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn)));
+    Node o3 = nm->mkNode(AND,
+        nm->mkNode(BITVECTOR_UGE, s, ww),
+        nm->mkNode(OR, t.eqNode(z), t.eqNode(n)));
+
+    scl = nm->mkNode(OR, o1, o2, o3);
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+  }
+  else
+  {
+    /* s >> x = t
+     * with side condition:
+     * (s[w-1:w-1] = 0 && t = 0)
+     * ||
+     * (s[w-1:w-1] = 1 && t == ~0)
+     * ||
+     * s = t
+     * || 
+     * \/ (t[w-1-i:0] = s[w-1:i]
+     *     && ((s[w-1:w-1] = 0 && t[w-1:w-i] = 0)
+     *         ||
+     *         (s[w-1:w-1] = 1 &&  t[w-1:w-i] = ~0)))
+     * for 0 < i < w
+     * where
+     * w = getSize(s) = getSize(t)
+     */
+    Node msbz = bv::utils::mkExtract(s, w-1, w-1).eqNode(bv::utils::mkZero(1));
+    Node msbn = bv::utils::mkExtract(s, w-1, w-1).eqNode(bv::utils::mkOnes(1));
+    NodeBuilder<> nb(nm, OR);
+    nb << nm->mkNode(EQUAL, t, s);
+    for (unsigned i = 1; i < w; ++i)
+    {
+
+      Node ext = bv::utils::mkExtract(t, w-1, w-i);
+
+      Node o1 = nm->mkNode(AND, msbz, ext.eqNode(bv::utils::mkZero(i)));
+      Node o2 = nm->mkNode(AND, msbn, ext.eqNode(bv::utils::mkOnes(i)));
+      Node o = nm->mkNode(OR, o1, o2);
+
+      Node e = nm->mkNode(EQUAL,
+          bv::utils::mkExtract(t, w-1-i, 0), bv::utils::mkExtract(s, w-1, i));
+
+      nb << nm->mkNode(AND, e, o);
+    }
+    nb << nm->mkNode(AND, msbz, t.eqNode(z));
+    nb << nm->mkNode(AND, msbn, t.eqNode(n));
+    scl = nb.constructNode();
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
+  }
   Node sc = nm->mkNode(IMPLIES, scl, scr);
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
@@ -550,36 +623,65 @@ static Node getScBvAshr(Kind k, unsigned idx, Node x, Node s, Node t)
 static Node getScBvShl(Kind k, unsigned idx, Node x, Node s, Node t)
 {
   Assert(k == BITVECTOR_SHL);
-  Assert(idx == 0);
 
   NodeManager* nm = NodeManager::currentNM();
+  Node scl, scr;
   unsigned w = bv::utils::getSize(s);
   Assert(w == bv::utils::getSize(t));
+  Node z = bv::utils::mkConst(w, 0u);
 
-  /* x << s = t
-   * with side condition:
-   * (s = 0 || ctz(t) >= s)
-   * <->
-   * (s = 0 || (s < w && ((t o z) >> (z o s))[w-1:0] = z) || (s >= w && t = 0)
-   *
-   * where
-   * w = getSize(s) = getSize(t) = getSize (z) && z = 0
-   */
+  if (idx == 0)
+  {
+    /* x << s = t
+     * with side condition:
+     * (s = 0 || ctz(t) >= s)
+     * <->
+     * (s = 0 || (s < w && ((t o z) >> (z o s))[w-1:0] = z) || (s >= w && t = 0)
+     *
+     * where
+     * w = getSize(s) = getSize(t) = getSize (z) && z = 0
+     */
 
-  Node z = bv::utils::mkConst(w, 0u);
-  Node ww = bv::utils::mkConst(w, w);
+    Node ww = bv::utils::mkConst(w, w);
 
-  Node shr = nm->mkNode(BITVECTOR_LSHR,
-      bv::utils::mkConcat(t, z),
-      bv::utils::mkConcat(z, s));
-  Node ext = bv::utils::mkExtract(shr, w - 1, 0);
+    Node shr = nm->mkNode(BITVECTOR_LSHR,
+        bv::utils::mkConcat(t, z),
+        bv::utils::mkConcat(z, s));
+    Node ext = bv::utils::mkExtract(shr, w - 1, 0);
 
-  Node o1 = nm->mkNode(EQUAL, s, z);
-  Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
-  Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
+    Node o1 = nm->mkNode(EQUAL, s, z);
+    Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z));
+    Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z));
 
-  Node scl = nm->mkNode(OR, o1, o2, o3);
-  Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+    scl = nm->mkNode(OR, o1, o2, o3);
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
+  }
+  else
+  {
+    /* s << x = t
+     * with side condition:
+     * t = 0
+     * ||
+     * s = t
+     * || 
+     * \/ (t[w-1:i] = s[w-1-i:0] && t[i-1:0] = 0) for 0 < i < w
+     * where
+     * w = getSize(s) = getSize(t)
+     */
+    NodeBuilder<> nb(nm, OR);
+    nb << nm->mkNode(EQUAL, t, s);
+    for (unsigned i = 1; i < w; ++i)
+    {
+      nb << nm->mkNode(AND,
+          nm->mkNode(EQUAL,
+            bv::utils::mkExtract(t, w-1, i), bv::utils::mkExtract(s, w-1-i, 0)),
+          nm->mkNode(EQUAL,
+            bv::utils::mkExtract(t, i-1, 0), bv::utils::mkZero(i)));
+    }
+    nb << t.eqNode(z);
+    scl = nb.constructNode();
+    scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t);
+  }
   Node sc = nm->mkNode(IMPLIES, scl, scr);
   Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl;
   return sc;
index 04c97a831afeec7dbe5294a5907bd296f5f4f175..ce01c17e40937fdda010da9c8678d5c652c4566a 100644 (file)
@@ -84,10 +84,6 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
            || k == kind::BITVECTOR_ASHR
            || k == kind::BITVECTOR_SHL);
     Assert(k != kind::BITVECTOR_UREM_TOTAL || idx == 1);
-    Assert((k != kind::BITVECTOR_LSHR
-            && k != kind::BITVECTOR_ASHR
-            && k != kind::BITVECTOR_SHL)
-           || idx == 0);
 
     Node sc = getsc(k, idx, d_sk, d_s, d_t);
     Kind ksc = sc.getKind();
@@ -98,6 +94,12 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
     Node scr = d_nm->mkNode(kind::EXISTS, d_bvarlist, body);
     Expr a = d_nm->mkNode(kind::DISTINCT, sc[0], scr).toExpr();
     Result res = d_smt->checkSat(a);
+    if (res.d_sat == Result::SAT)
+    {
+      std::cout << std::endl << "s " << d_smt->getValue(d_s.toExpr()) << std::endl;
+      std::cout << "t " << d_smt->getValue(d_t.toExpr()) << std::endl;
+      std::cout << "x " << d_smt->getValue(d_x.toExpr()) << std::endl;
+    }
     TS_ASSERT(res.d_sat == Result::UNSAT);
   }
 
@@ -110,6 +112,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
     d_nm = NodeManager::fromExprManager(d_em);
     d_smt = new SmtEngine(d_em);
     d_smt->setOption("cbqi-bv", CVC4::SExpr(false));
+    d_smt->setOption("produce-models", CVC4::SExpr(true));
     d_scope = new SmtScope(d_smt);
 
     d_s = d_nm->mkVar("s", d_nm->mkBitVectorType(4));
@@ -243,8 +246,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
 
   void testGetScBvLshr1()
   {
-    TS_ASSERT_THROWS(runTest(BITVECTOR_LSHR, 1, getScBvLshr),
-                     AssertionException);
+    runTest(BITVECTOR_LSHR, 1, getScBvLshr);
   }
 
   void testGetScBvAshr0()
@@ -254,8 +256,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
 
   void testGetScBvAshr1()
   {
-    TS_ASSERT_THROWS(runTest(BITVECTOR_ASHR, 1, getScBvAshr),
-                     AssertionException);
+    runTest(BITVECTOR_ASHR, 1, getScBvAshr);
   }
 
   void testGetScBvShl0()
@@ -265,8 +266,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite
 
   void testGetScBvShl1()
   {
-    TS_ASSERT_THROWS(runTest(BITVECTOR_SHL, 1, getScBvShl),
-                     AssertionException);
+    runTest(BITVECTOR_SHL, 1, getScBvShl);
   }
 
 };