From: Aina Niemetz Date: Thu, 14 Dec 2017 03:10:16 +0000 (-0800) Subject: Add missing side conditions for SHL, LSHR, ASHR for CBQI BV. (#1441) X-Git-Tag: cvc5-1.0.0~5414 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=26214b7b02e90fca270e6bac7d6b64ea1a6d723a;p=cvc5.git Add missing side conditions for SHL, LSHR, ASHR for CBQI BV. (#1441) This adds side conditions for operators BITVECTOR_SHL, BITVECTOR_LSHR and BITVECTOR_ASHR for index = 1, i.e., s << x = t and s >> x = t. Previously, we treated these cases as non-invertible. --- diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index bad26d14f..d3fa0715d 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -115,9 +115,9 @@ static bool isInvertible(Kind k, unsigned index) || k == BITVECTOR_AND || k == BITVECTOR_OR || k == BITVECTOR_XOR - || (k == BITVECTOR_LSHR && index == 0) - || (k == BITVECTOR_ASHR && index == 0) - || (k == BITVECTOR_SHL && index == 0); + || k == BITVECTOR_LSHR + || k == BITVECTOR_ASHR + || k == BITVECTOR_SHL; } Node BvInverter::getPathToPv( @@ -463,34 +463,63 @@ static Node getScBvAndOr(Kind k, unsigned idx, Node x, Node s, Node t) static Node getScBvLshr(Kind k, unsigned idx, Node x, Node s, Node t) { Assert(k == BITVECTOR_LSHR); - Assert(idx == 0); NodeManager* nm = NodeManager::currentNM(); + Node scl, scr; unsigned w = bv::utils::getSize(s); Assert(w == bv::utils::getSize(t)); + Node z = bv::utils::mkZero(w); - /* x >> s = t - * with side condition: - * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0) - * -> - * s = 0 || (s < w && ((z o t) << (z o s))[2w-1 : w] = z) || (s >= w && t = 0) - * with w = getSize(t) = getSize(s) - * and z = 0 with getSize(z) = w */ + if (idx == 0) + { + /* x >> s = t + * with side condition: + * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0) + * -> + * s = 0 || (s < w && ((z o t) << (z o s))[2w-1 : w] = z) || (s >= w && t = 0) + * with w = getSize(t) = getSize(s) + * and z = 0 with getSize(z) = w */ - Node z = bv::utils::mkZero(w); - Node ww = bv::utils::mkConst(w, w); + Node ww = bv::utils::mkConst(w, w); - Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t); - Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s); - Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); - Node ext = bv::utils::mkExtract(shl, 2*w-1, w); + Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t); + Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s); + Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); + Node ext = bv::utils::mkExtract(shl, 2*w-1, w); - Node o1 = s.eqNode(z); - Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); - Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); + Node o1 = s.eqNode(z); + Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); + Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); - Node scl = nm->mkNode(OR, o1, o2, o3); - Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + scl = nm->mkNode(OR, o1, o2, o3); + scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + } + else + { + /* s >> x = t + * with side condition: + * t = 0 + * || + * s = t + * || + * \/ (t[w-1-i:0] = s[w-1:i] && t[w-1:w-i] = 0) for 0 < i < w + * where + * w = getSize(s) = getSize(t) + */ + NodeBuilder<> nb(nm, OR); + nb << nm->mkNode(EQUAL, t, s); + for (unsigned i = 1; i < w; ++i) + { + nb << nm->mkNode(AND, + nm->mkNode(EQUAL, + bv::utils::mkExtract(t, w-1-i, 0), bv::utils::mkExtract(s, w-1, i)), + nm->mkNode(EQUAL, + bv::utils::mkExtract(t, w-1, w-i), bv::utils::mkZero(i))); + } + nb << t.eqNode(z); + scl = nb.constructNode(); + scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t); + } Node sc = nm->mkNode(IMPLIES, scl, scr); Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl; return sc; @@ -499,49 +528,93 @@ static Node getScBvLshr(Kind k, unsigned idx, Node x, Node s, Node t) static Node getScBvAshr(Kind k, unsigned idx, Node x, Node s, Node t) { Assert(k == BITVECTOR_ASHR); - Assert(idx == 0); NodeManager* nm = NodeManager::currentNM(); + Node scl, scr; unsigned w = bv::utils::getSize(s); Assert(w == bv::utils::getSize(t)); - - /* x >> s = t - * with side condition: - * s = 0 - * || - * (s < w && (((z o t) << (z o s))[2w-1:w-1] = z - * || - * ((~z o t) << (z o s))[2w-1:w-1] = ~z)) - * || - * (s >= w && (t = 0 || t = ~0)) - * with w = getSize(t) = getSize(s) - * and z = 0 with getSize(z) = w */ - Node z = bv::utils::mkZero(w); - Node zz = bv::utils::mkZero(w+1); Node n = bv::utils::mkOnes(w); - Node nn = bv::utils::mkOnes(w+1); - Node ww = bv::utils::mkConst(w, w); - - Node z_o_t = bv::utils::mkConcat(z, t); - Node z_o_s = bv::utils::mkConcat(z, s); - Node n_o_t = bv::utils::mkConcat(n, t); - - Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); - Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s); - Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1); - Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1); - - Node o1 = s.eqNode(z); - Node o2 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_ULT, s, ww), - nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn))); - Node o3 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_UGE, s, ww), - nm->mkNode(OR, t.eqNode(z), t.eqNode(n))); - - Node scl = nm->mkNode(OR, o1, o2, o3); - Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + + if (idx == 0) + { + /* x >> s = t + * with side condition: + * s = 0 + * || + * (s < w && (((z o t) << (z o s))[2w-1:w-1] = z + * || + * ((~z o t) << (z o s))[2w-1:w-1] = ~z)) + * || + * (s >= w && (t = 0 || t = ~0)) + * with w = getSize(t) = getSize(s) + * and z = 0 with getSize(z) = w */ + + Node zz = bv::utils::mkZero(w+1); + Node nn = bv::utils::mkOnes(w+1); + Node ww = bv::utils::mkConst(w, w); + + Node z_o_t = bv::utils::mkConcat(z, t); + Node z_o_s = bv::utils::mkConcat(z, s); + Node n_o_t = bv::utils::mkConcat(n, t); + + Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); + Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s); + Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1); + Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1); + + Node o1 = s.eqNode(z); + Node o2 = nm->mkNode(AND, + nm->mkNode(BITVECTOR_ULT, s, ww), + nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn))); + Node o3 = nm->mkNode(AND, + nm->mkNode(BITVECTOR_UGE, s, ww), + nm->mkNode(OR, t.eqNode(z), t.eqNode(n))); + + scl = nm->mkNode(OR, o1, o2, o3); + scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + } + else + { + /* s >> x = t + * with side condition: + * (s[w-1:w-1] = 0 && t = 0) + * || + * (s[w-1:w-1] = 1 && t == ~0) + * || + * s = t + * || + * \/ (t[w-1-i:0] = s[w-1:i] + * && ((s[w-1:w-1] = 0 && t[w-1:w-i] = 0) + * || + * (s[w-1:w-1] = 1 && t[w-1:w-i] = ~0))) + * for 0 < i < w + * where + * w = getSize(s) = getSize(t) + */ + Node msbz = bv::utils::mkExtract(s, w-1, w-1).eqNode(bv::utils::mkZero(1)); + Node msbn = bv::utils::mkExtract(s, w-1, w-1).eqNode(bv::utils::mkOnes(1)); + NodeBuilder<> nb(nm, OR); + nb << nm->mkNode(EQUAL, t, s); + for (unsigned i = 1; i < w; ++i) + { + + Node ext = bv::utils::mkExtract(t, w-1, w-i); + + Node o1 = nm->mkNode(AND, msbz, ext.eqNode(bv::utils::mkZero(i))); + Node o2 = nm->mkNode(AND, msbn, ext.eqNode(bv::utils::mkOnes(i))); + Node o = nm->mkNode(OR, o1, o2); + + Node e = nm->mkNode(EQUAL, + bv::utils::mkExtract(t, w-1-i, 0), bv::utils::mkExtract(s, w-1, i)); + + nb << nm->mkNode(AND, e, o); + } + nb << nm->mkNode(AND, msbz, t.eqNode(z)); + nb << nm->mkNode(AND, msbn, t.eqNode(n)); + scl = nb.constructNode(); + scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t); + } Node sc = nm->mkNode(IMPLIES, scl, scr); Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl; return sc; @@ -550,36 +623,65 @@ static Node getScBvAshr(Kind k, unsigned idx, Node x, Node s, Node t) static Node getScBvShl(Kind k, unsigned idx, Node x, Node s, Node t) { Assert(k == BITVECTOR_SHL); - Assert(idx == 0); NodeManager* nm = NodeManager::currentNM(); + Node scl, scr; unsigned w = bv::utils::getSize(s); Assert(w == bv::utils::getSize(t)); + Node z = bv::utils::mkConst(w, 0u); - /* x << s = t - * with side condition: - * (s = 0 || ctz(t) >= s) - * <-> - * (s = 0 || (s < w && ((t o z) >> (z o s))[w-1:0] = z) || (s >= w && t = 0) - * - * where - * w = getSize(s) = getSize(t) = getSize (z) && z = 0 - */ + if (idx == 0) + { + /* x << s = t + * with side condition: + * (s = 0 || ctz(t) >= s) + * <-> + * (s = 0 || (s < w && ((t o z) >> (z o s))[w-1:0] = z) || (s >= w && t = 0) + * + * where + * w = getSize(s) = getSize(t) = getSize (z) && z = 0 + */ - Node z = bv::utils::mkConst(w, 0u); - Node ww = bv::utils::mkConst(w, w); + Node ww = bv::utils::mkConst(w, w); - Node shr = nm->mkNode(BITVECTOR_LSHR, - bv::utils::mkConcat(t, z), - bv::utils::mkConcat(z, s)); - Node ext = bv::utils::mkExtract(shr, w - 1, 0); + Node shr = nm->mkNode(BITVECTOR_LSHR, + bv::utils::mkConcat(t, z), + bv::utils::mkConcat(z, s)); + Node ext = bv::utils::mkExtract(shr, w - 1, 0); - Node o1 = nm->mkNode(EQUAL, s, z); - Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); - Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); + Node o1 = nm->mkNode(EQUAL, s, z); + Node o2 = nm->mkNode(AND, nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); + Node o3 = nm->mkNode(AND, nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); - Node scl = nm->mkNode(OR, o1, o2, o3); - Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + scl = nm->mkNode(OR, o1, o2, o3); + scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + } + else + { + /* s << x = t + * with side condition: + * t = 0 + * || + * s = t + * || + * \/ (t[w-1:i] = s[w-1-i:0] && t[i-1:0] = 0) for 0 < i < w + * where + * w = getSize(s) = getSize(t) + */ + NodeBuilder<> nb(nm, OR); + nb << nm->mkNode(EQUAL, t, s); + for (unsigned i = 1; i < w; ++i) + { + nb << nm->mkNode(AND, + nm->mkNode(EQUAL, + bv::utils::mkExtract(t, w-1, i), bv::utils::mkExtract(s, w-1-i, 0)), + nm->mkNode(EQUAL, + bv::utils::mkExtract(t, i-1, 0), bv::utils::mkZero(i))); + } + nb << t.eqNode(z); + scl = nb.constructNode(); + scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t); + } Node sc = nm->mkNode(IMPLIES, scl, scr); Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl; return sc; diff --git a/test/unit/theory/theory_quantifiers_bv_inverter_white.h b/test/unit/theory/theory_quantifiers_bv_inverter_white.h index 04c97a831..ce01c17e4 100644 --- a/test/unit/theory/theory_quantifiers_bv_inverter_white.h +++ b/test/unit/theory/theory_quantifiers_bv_inverter_white.h @@ -84,10 +84,6 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite || k == kind::BITVECTOR_ASHR || k == kind::BITVECTOR_SHL); Assert(k != kind::BITVECTOR_UREM_TOTAL || idx == 1); - Assert((k != kind::BITVECTOR_LSHR - && k != kind::BITVECTOR_ASHR - && k != kind::BITVECTOR_SHL) - || idx == 0); Node sc = getsc(k, idx, d_sk, d_s, d_t); Kind ksc = sc.getKind(); @@ -98,6 +94,12 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite Node scr = d_nm->mkNode(kind::EXISTS, d_bvarlist, body); Expr a = d_nm->mkNode(kind::DISTINCT, sc[0], scr).toExpr(); Result res = d_smt->checkSat(a); + if (res.d_sat == Result::SAT) + { + std::cout << std::endl << "s " << d_smt->getValue(d_s.toExpr()) << std::endl; + std::cout << "t " << d_smt->getValue(d_t.toExpr()) << std::endl; + std::cout << "x " << d_smt->getValue(d_x.toExpr()) << std::endl; + } TS_ASSERT(res.d_sat == Result::UNSAT); } @@ -110,6 +112,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite d_nm = NodeManager::fromExprManager(d_em); d_smt = new SmtEngine(d_em); d_smt->setOption("cbqi-bv", CVC4::SExpr(false)); + d_smt->setOption("produce-models", CVC4::SExpr(true)); d_scope = new SmtScope(d_smt); d_s = d_nm->mkVar("s", d_nm->mkBitVectorType(4)); @@ -243,8 +246,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite void testGetScBvLshr1() { - TS_ASSERT_THROWS(runTest(BITVECTOR_LSHR, 1, getScBvLshr), - AssertionException); + runTest(BITVECTOR_LSHR, 1, getScBvLshr); } void testGetScBvAshr0() @@ -254,8 +256,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite void testGetScBvAshr1() { - TS_ASSERT_THROWS(runTest(BITVECTOR_ASHR, 1, getScBvAshr), - AssertionException); + runTest(BITVECTOR_ASHR, 1, getScBvAshr); } void testGetScBvShl0() @@ -265,8 +266,7 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite void testGetScBvShl1() { - TS_ASSERT_THROWS(runTest(BITVECTOR_SHL, 1, getScBvShl), - AssertionException); + runTest(BITVECTOR_SHL, 1, getScBvShl); } };