From 54fa2b48723ace784f2fa5e710aef6c2a38a7bd2 Mon Sep 17 00:00:00 2001 From: Aina Niemetz Date: Tue, 2 Jan 2018 20:30:04 -0800 Subject: [PATCH] Add side conditions for inequalities over BITVECTOR_UDIV for CBQI BV. (#1464) We now can handle all cases of (in|dis)equality over BITVECTOR_UREM. This also simplifies some of the side conditions for equality. --- src/theory/quantifiers/bv_inverter.cpp | 326 +++++++++++++++--- .../theory_quantifiers_bv_inverter_white.h | 1 - 2 files changed, 278 insertions(+), 49 deletions(-) diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index 67ce8d217..c794c693c 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -551,13 +551,15 @@ static Node getScBvUdiv(bool pol, Node t) { Assert(k == BITVECTOR_UDIV_TOTAL); + Assert(litk == EQUAL + || litk == BITVECTOR_ULT || litk == BITVECTOR_SLT + || litk == BITVECTOR_UGT || litk == BITVECTOR_SGT); NodeManager* nm = NodeManager::currentNM(); unsigned w = bv::utils::getSize(s); Assert (w == bv::utils::getSize(t)); - Node sc, scl, scr; + Node scl; Node z = bv::utils::mkZero(w); - Node n = bv::utils::mkOnes(w); if (litk == EQUAL) { @@ -566,74 +568,302 @@ static Node getScBvUdiv(bool pol, if (pol) { /* x udiv s = t - * with side condition: - * t = ~0 && (s = 0 || s = 1) - * || - * (t != ~0 && s != 0 && !umulo(s * t)) */ - Node one = bv::utils::mkOne(w); - Node o1 = nm->mkNode(AND, - t.eqNode(n), - nm->mkNode(OR, s.eqNode(z), s.eqNode(one))); - Node o2 = nm->mkNode(AND, - t.eqNode(n).notNode(), - s.eqNode(z).notNode(), - nm->mkNode(NOT, bv::utils::mkUmulo(s, t))); - - scl = nm->mkNode(OR, o1, o2); - scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); + * with side condition (synthesized): + * (= (bvudiv (bvmul s t) s) t) + * + * is equivalent to: + * (or + * (and (= t (bvnot z)) + * (or (= s z) (= s (_ bv1 w)))) + * (and (distinct t (bvnot z)) + * (distinct s z) + * (not (umulo s t)))) + * + * where umulo(s, t) is true if s * t produces and overflow + * and z = 0 with getSize(z) = w */ + Node mul = nm->mkNode(BITVECTOR_MULT, s, t); + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, s); + scl = nm->mkNode(EQUAL, div, t); } else { /* x udiv s != t * with side condition: - * s != 0 || t != ~0 */ - scl = nm->mkNode(OR, s.eqNode(z).notNode(), t.eqNode(n).notNode()); - scr = nm->mkNode(DISTINCT, nm->mkNode(k, x, s), t); + * (or (distinct s z) (distinct t ones)) + * where z = 0 with getSize(z) = w and ones = ~0 */ + Node ones = bv::utils::mkOnes(w); + scl = nm->mkNode(OR, s.eqNode(z).notNode(), t.eqNode(ones).notNode()); } - sc = nm->mkNode(IMPLIES, scl, scr); } else { if (pol) { /* s udiv x = t + * with side condition (synthesized): + * (= (bvudiv s (bvudiv s t)) t) + * + * is equivalent to: + * (or + * (= s t) + * (= t (bvnot z)) + * (and + * (bvuge s t) + * (or + * (= (bvurem s t) z) + * (bvule (bvadd (bvudiv s (bvadd t (_ bv1 w))) (_ bv1 w)) + * (bvudiv s t))) + * (=> (= s (bvnot (_ bv0 8))) (distinct t (_ bv0 8))))) + * + * where z = 0 with getSize(z) = w */ + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, s, t); + scl = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, s, div), t); + } + else + { + /* s udiv x != t + * true (no side condition) */ + scl = nm->mkConst(true); + } + } + } + else if (litk == BITVECTOR_ULT) + { + if (idx == 0) + { + if (pol) + { + /* x udiv s < t + * with side condition (synthesized): + * (and (bvult z s) (bvult z t)) + * where z = 0 with getSize(z) = w */ + Node u1 = nm->mkNode(BITVECTOR_ULT, z, s); + Node u2 = nm->mkNode(BITVECTOR_ULT, z, t); + scl = nm->mkNode(AND, u1, u2); + } + else + { + /* x udiv s >= t + * with side condition (synthesized): + * (= (bvand (bvudiv (bvmul s t) t) s) s) */ + Node mul = nm->mkNode(BITVECTOR_MULT, s, t); + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, t); + scl = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_AND, div, s), s); + } + } + else + { + if (pol) + { + /* s udiv x < t + * with side condition (synthesized): + * (and (bvult z (bvnot (bvand (bvneg t) s))) (bvult z t)) + * where z = 0 with getSize(z) = w */ + Node a = nm->mkNode(BITVECTOR_AND, nm->mkNode(BITVECTOR_NEG, t), s); + Node u1 = nm->mkNode(BITVECTOR_ULT, z, nm->mkNode(BITVECTOR_NOT, a)); + Node u2 = nm->mkNode(BITVECTOR_ULT, z, t); + scl = nm->mkNode(AND, u1, u2); + } + else + { + /* s udiv x >= t + * true (no side condition) */ + scl = nm->mkConst(true); + } + } + } + else if (litk == BITVECTOR_UGT) + { + if (idx == 0) + { + if (pol) + { + /* x udiv s > t * with side condition: - * s = t - * || - * t = ~0 - * || - * (s >= t - * && (s % t = 0 || (s / (t+1) +1) <= s / t) - * && (s = ~0 => t != 0)) */ - Node oo1 = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UREM_TOTAL, s, t), z); - Node udiv = nm->mkNode(BITVECTOR_UDIV_TOTAL, s, bv::utils::mkInc(t)); - Node ule1 = bv::utils::mkInc(udiv); - Node ule2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, s, t); - Node oo2 = nm->mkNode(BITVECTOR_ULE, ule1, ule2); - - Node a1 = nm->mkNode(BITVECTOR_UGE, s, t); - Node a2 = nm->mkNode(OR, oo1, oo2); - Node a3 = nm->mkNode(IMPLIES, s.eqNode(n), t.eqNode(z).notNode()); - - Node o1 = s.eqNode(t); - Node o2 = t.eqNode(n); - Node o3 = nm->mkNode(AND, a1, a2, a3); - - scl = nm->mkNode(OR, o1, o2, o3); - scr = nm->mkNode(EQUAL, nm->mkNode(k, s, x), t); - sc = nm->mkNode(IMPLIES, scl, scr); + * (bvugt (bvudiv ones s) t) + * with ones = ~0 */ + Node ones = bv::utils::mkOnes(w); + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s); + scl = nm->mkNode(BITVECTOR_UGT, div, t); + } + else + { + /* x udiv s <= t + * with side condition (synthesized): + * (not (bvult (bvor s t) (bvnot (bvneg s)))) */ + Node u1 = nm->mkNode(BITVECTOR_OR, s, t); + Node u2 = nm->mkNode(BITVECTOR_NOT, nm->mkNode(BITVECTOR_NEG, s)); + scl = nm->mkNode(BITVECTOR_UGE, u1, u2); + } + } + else + { + if (pol) + { + /* s udiv x > t + * with side condition (synthesized): + * (bvult t ones) + * with ones = ~0 */ + Node ones = bv::utils::mkOnes(w); + scl = nm->mkNode(BITVECTOR_ULT, t, ones); } else { - sc = nm->mkNode(DISTINCT, nm->mkNode(k, s, x), t); + /* s udiv x <= t + * with side condition (synthesized): + * (bvult z (bvor (bvnot s) t)) + * where z = 0 with getSize(z) = w */ + scl = nm->mkNode(BITVECTOR_ULT, + z, nm->mkNode(BITVECTOR_OR, nm->mkNode(BITVECTOR_NOT, s), t)); } } } - else + else if (litk == BITVECTOR_SLT) { - return Node::null(); + if (idx == 0) + { + if (pol) + { + /* x udiv s < t + * with side condition: + * (=> (bvsle t z) (bvslt (bvudiv min s) t)) + * where z = 0 with getSize(z) = w + * and min is the minimum signed value */ + BitVector bv_min = BitVector(w).setBit(w - 1); + Node min = bv::utils::mkConst(bv_min); + Node sle = nm->mkNode(BITVECTOR_SLE, t, z); + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, min, s); + Node slt = nm->mkNode(BITVECTOR_SLT, div, t); + scl = nm->mkNode(IMPLIES, sle, slt); + } + else + { + /* x udiv s >= t + * with side condition: + * (or + * (bvsge (bvudiv ones s) t) + * (bvsge (bvudiv max s) t)) + * with ones = ~0 and max the maximum signed value */ + BitVector bv_ones = utils::mkBitVectorOnes(w - 1); + BitVector bv_max = BitVector(1).concat(bv_ones); + Node max = bv::utils::mkConst(bv_max); + Node ones = bv::utils::mkOnes(w); + Node udiv1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s); + Node udiv2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, max, s); + Node sge1 = nm->mkNode(BITVECTOR_SGE, udiv1, t); + Node sge2 = nm->mkNode(BITVECTOR_SGE, udiv2, t); + scl = nm->mkNode(OR, sge1, sge2); + } + } + else + { + if (pol) + { + /* s udiv x < t + * with side condition (synthesized): + * (or (bvslt s t) (bvsge t z)) + * where z = 0 with getSize(z) = w */ + Node slt = nm->mkNode(BITVECTOR_SLT, s, t); + Node sge = nm->mkNode(BITVECTOR_SGE, t, z); + scl = nm->mkNode(OR, slt, sge); + } + else + { + /* s udiv x >= t + * with side condition: + * (and + * (=> (bvsge s z) (bvsge s t)) + * (=> (bvslt s z) (bvsge (bvudiv s (_ bv2 w)) t))) + * where z = 0 with getSize(z) = w */ + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, + s, bv::utils::mkConst(w, 2)); + Node i1 = nm->mkNode(IMPLIES, + nm->mkNode(BITVECTOR_SGE, s, z), nm->mkNode(BITVECTOR_SGE, s, t)); + Node i2 = nm->mkNode(IMPLIES, + nm->mkNode(BITVECTOR_SLT, s, z), nm->mkNode(BITVECTOR_SGE, div, t)); + scl = nm->mkNode(AND, i1, i2); + } + } + } + else /* litk == BITVECTOR_SGT */ + { + if (idx == 0) + { + if (pol) + { + /* x udiv s > t + * with side condition: + * (or + * (bvsgt (bvudiv ones s) t) + * (bvsgt (bvudiv max s) t)) + * with ones = ~0 and max the maximum signed value */ + BitVector bv_ones = utils::mkBitVectorOnes(w - 1); + BitVector bv_max = BitVector(1).concat(bv_ones); + Node max = bv::utils::mkConst(bv_max); + Node ones = bv::utils::mkOnes(w); + Node div1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, ones, s); + Node sgt1 = nm->mkNode(BITVECTOR_SGT, div1, t); + Node div2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, max, s); + Node sgt2 = nm->mkNode(BITVECTOR_SGT, div2, t); + scl = nm->mkNode(OR, sgt1, sgt2); + } + else + { + /* x udiv s <= t + * with side condition (combination of = and <): + * (or + * (= (bvudiv (bvmul s t) s) t) ; eq, synthesized + * (=> (bvsle t z) (bvslt (bvudiv min s) t))) ; slt + * where z = 0 with getSize(z) = w */ + Node mul = nm->mkNode(BITVECTOR_MULT, s, t); + Node div1 = nm->mkNode(BITVECTOR_UDIV_TOTAL, mul, s); + Node o1 = nm->mkNode(EQUAL, div1, t); + BitVector bv_min = BitVector(w).setBit(w - 1); + Node min = bv::utils::mkConst(bv_min); + Node sle = nm->mkNode(BITVECTOR_SLE, t, z); + Node div2 = nm->mkNode(BITVECTOR_UDIV_TOTAL, min, s); + Node slt = nm->mkNode(BITVECTOR_SLT, div2, t); + Node o2 = nm->mkNode(IMPLIES, sle, slt); + scl = nm->mkNode(OR, o1, o2); + } + } + else + { + if (pol) + { + /* s udiv x > t + * with side condition: + * (and + * (=> (bvsge s z) (bvsgt s t)) + * (=> (bvslt s z) (bvsgt (bvudiv s (_ bv2 w)) t))) + * where z = 0 with getSize(z) = w */ + Node div = nm->mkNode(BITVECTOR_UDIV_TOTAL, + s, bv::utils::mkConst(w, 2)); + Node i1 = nm->mkNode(IMPLIES, + nm->mkNode(BITVECTOR_SGE, s, z), nm->mkNode(BITVECTOR_SGT, s, t)); + Node i2 = nm->mkNode(IMPLIES, + nm->mkNode(BITVECTOR_SLT, s, z), nm->mkNode(BITVECTOR_SGT, div, t)); + scl = nm->mkNode(AND, i1, i2); + } + else + { + /* s udiv x <= t + * with side condition (synthesized): + * (not (and (bvslt t (bvnot #x0)) (bvslt t s))) + * <-> + * (or (bvsge t ones) (bvsge t s)) + * with ones = ~0 */ + Node ones = bv::utils::mkOnes(w); + Node sge1 = nm->mkNode(BITVECTOR_SGE, t, ones); + Node sge2 = nm->mkNode(BITVECTOR_SGE, t, s); + scl = nm->mkNode(OR, sge1, sge2); + } + } } + Node scr = + nm->mkNode(litk, idx == 0 ? nm->mkNode(k, x, s) : nm->mkNode(k, s, x), t); + Node sc = nm->mkNode(IMPLIES, scl, pol ? scr : scr.notNode()); Trace("bv-invert") << "Add SC_" << k << "(" << x << "): " << sc << std::endl; return sc; } diff --git a/test/unit/theory/theory_quantifiers_bv_inverter_white.h b/test/unit/theory/theory_quantifiers_bv_inverter_white.h index dc7164e54..291e2252d 100644 --- a/test/unit/theory/theory_quantifiers_bv_inverter_white.h +++ b/test/unit/theory/theory_quantifiers_bv_inverter_white.h @@ -101,7 +101,6 @@ class TheoryQuantifiersBvInverter : public CxxTest::TestSuite || k == BITVECTOR_OR || k == BITVECTOR_LSHR || k == BITVECTOR_ASHR - || k == BITVECTOR_SHL); Node sc = getsc(pol, litk, k, idx, d_sk, d_s, d_t); -- 2.30.2