From: Mathias Preiner Date: Sat, 6 Jan 2018 02:05:41 +0000 (-0800) Subject: Use simpler EQUAL SCs for LSHR0, LSHR1, ASHR0, AHSR1, SHL0, SHL1. (#1482) X-Git-Tag: cvc5-1.0.0~5380 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=947b04aafc5d45fe901eccebc151f7a9c90de00e;p=cvc5.git Use simpler EQUAL SCs for LSHR0, LSHR1, ASHR0, AHSR1, SHL0, SHL1. (#1482) --- diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index 7f2343df7..7407051be 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -229,6 +229,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) else { /* x >= t + * with side condition: * true (no side condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } @@ -250,6 +251,7 @@ static Node getScBvUltUgt(bool pol, Kind k, Node x, Node t) else { /* x <= t + * with side condition: * true (no side condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } @@ -283,6 +285,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) else { /* x >= t + * with side condition: * true (no side condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } @@ -306,6 +309,7 @@ static Node getScBvSltSgt(bool pol, Kind k, Node x, Node t) else { /* x <= t + * with side condition: * true (no side condition) */ sc = nm->mkNode(NOT, nm->mkNode(k, x, t)); } @@ -1343,6 +1347,29 @@ static Node getScBvAndOr(bool pol, return sc; } +static Node defaultShiftSc(Kind litk, Kind shk, Node s, Node t) +{ + unsigned w; + NodeBuilder<> nb(OR); + NodeManager *nm; + + nm = NodeManager::currentNM(); + + w = bv::utils::getSize(s); + Assert(w == bv::utils::getSize(t)); + + nb << nm->mkNode(litk, s, t); + for (unsigned i = 1; i <= w; i++) + { + Node sw = bv::utils::mkConst(w, i); + nb << nm->mkNode(litk, nm->mkNode(shk, s, sw), t); + } + if (nb.getNumChildren() == 1) + return nb[0]; + return nb.constructNode(); +} + + static Node getScBvLshr(bool pol, Kind litk, Kind k, @@ -1370,29 +1397,11 @@ static Node getScBvLshr(bool pol, if (pol) { /* x >> s = t - * with side condition: - * s = 0 || (s < w && clz(t) >=s) || (s >= w && t = 0) - * -> - * (or - * (= s z) - * (and - * (bvult s w) - * (= ((_ extract 2*w-1 w) (bvshl (concat z t) (concat z s))) z)) - * (and (bvuge s w) (= t z))) - * with w = getSize(t) = getSize(s) - * and z = 0 with getSize(z) = w */ - Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t); - Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s); - Node shl = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); - Node ext = bv::utils::mkExtract(shl, 2*w-1, w); - - Node o1 = s.eqNode(z); - Node o2 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); - Node o3 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); - - scl = nm->mkNode(OR, o1, o2, o3); + * with side condition (synthesized): + * (= (bvlshr (bvshl t s) s) t) */ + Node shl = nm->mkNode(BITVECTOR_SHL, t, s); + Node lshr = nm->mkNode(BITVECTOR_LSHR, shl, s); + scl = lshr.eqNode(t); } else { @@ -1413,30 +1422,9 @@ static Node getScBvLshr(bool pol, { /* s >> x = t * with side condition: - * (or - * (= t z) - * (= s t) - * (and - * (= ((_ extract w-1-i 0) t) ((_ extract w-1 i) s)) - * (= ((_ extract w-1 w-i) t) z_i))) - * for 0 < i < w - * where - * w = getSize(s) = getSize(t) - * and z_i = 0 with getSize(z_i) = i */ - NodeBuilder<> nb(nm, OR); - nb << nm->mkNode(EQUAL, t, s); - for (unsigned i = 1; i < w; ++i) - { - nb << nm->mkNode(AND, - nm->mkNode(EQUAL, - bv::utils::mkExtract(t, w - 1 - i, 0), - bv::utils::mkExtract(s, w - 1, i)), - nm->mkNode(EQUAL, - bv::utils::mkExtract(t, w - 1, w - i), - bv::utils::mkZero(i))); - } - nb << t.eqNode(z); - scl = nb.constructNode(); + * (or (= (bvlshr s i) t) ...) + * for i in 0..w */ + scl = defaultShiftSc(EQUAL, BITVECTOR_LSHR, s, t); } else { @@ -1681,41 +1669,23 @@ static Node getScBvAshr(bool pol, { /* x >> s = t * with side condition: - * (or - * (= s z) - * (and - * (bvult s w) - * (or - * (= ((_ extract 2*w-1 w-1) (bvshl (concat z t) (concat z s))) z) - * (= ((_ extract 2*w-1 w-1) - * (bvshl (concat ones t) (concat z s))) ones))) - * (and (bvuge s w) (or (= t z) (= t ones)))) + * (and + * (=> (bvult s w) (= (bvashr (bvshl t s) s) t)) + * (=> (bvuge s w) (or (= t ones) (= t z))) + * ) * where * z = 0 with getSize(z) = w * and ones = ~0 with getSize(ones) = w * and w = getSize(t) = getSize(s) */ - Node zz = bv::utils::mkZero(w+1); - Node nn = bv::utils::mkOnes(w+1); Node ww = bv::utils::mkConst(w, w); - - Node z_o_t = bv::utils::mkConcat(z, t); - Node z_o_s = bv::utils::mkConcat(z, s); - Node n_o_t = bv::utils::mkConcat(n, t); - - Node shlz = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); - Node shln = nm->mkNode(BITVECTOR_SHL, n_o_t, z_o_s); - Node extz = bv::utils::mkExtract(shlz, 2*w-1, w-1); - Node extn = bv::utils::mkExtract(shln, 2*w-1, w-1); - - Node o1 = s.eqNode(z); - Node o2 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_ULT, s, ww), - nm->mkNode(OR, extz.eqNode(zz), extn.eqNode(nn))); - Node o3 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_UGE, s, ww), - nm->mkNode(OR, t.eqNode(z), t.eqNode(n))); - - scl = nm->mkNode(OR, o1, o2, o3); + Node shl = nm->mkNode(BITVECTOR_SHL, t, s); + Node ashr = nm->mkNode(BITVECTOR_ASHR, shl, s); + Node ult = nm->mkNode(BITVECTOR_ULT, s, ww); + Node imp1 = ult.impNode(ashr.eqNode(t)); + Node to = t.eqNode(n); + Node tz = t.eqNode(z); + Node imp2 = ult.notNode().impNode(to.orNode(tz)); + scl = imp1.andNode(imp2); } else { @@ -1730,48 +1700,9 @@ static Node getScBvAshr(bool pol, { /* s >> x = t * with side condition: - * (or - * (and (= ((_ extract w-1 w-1) s) (_ bv0 1)) (= t z)) - * (and (= ((_ extract w-1 w-1) s) (_ bv0 1)) (= t ones)) - * (= s t) - * (and - * (= ((_ extract w-1-i 0) t) ((_ extract w-1 i) s)) - * (or - * (and - * (= ((_ extract w-1 w-1) s) (_ bv0 1)) - * (= ((_ extract w-1 w-i) t) (_ bv0 i))) - * (and - * (= ((_ extract w-1 w-1) s) (_ bv1 1)) - * (= ((_ extract w-1 w-i) t) ones_i))))) - * for 0 < i < w - * where - * z = 0 and getSize(z) = w - * and ones = ~0 and getSize(ones) = w - * and ones_i = ~0 and getSize(ones_i) = i - * and w = getSize(s) = getSize(t) */ - Node msbz = bv::utils::mkExtract( - s, w-1, w-1).eqNode(bv::utils::mkZero(1)); - Node msbn = bv::utils::mkExtract( - s, w-1, w-1).eqNode(bv::utils::mkOnes(1)); - NodeBuilder<> nb(nm, OR); - nb << nm->mkNode(EQUAL, t, s); - for (unsigned i = 1; i < w; ++i) - { - Node ext = bv::utils::mkExtract(t, w-1, w-i); - - Node o1 = nm->mkNode(AND, msbz, ext.eqNode(bv::utils::mkZero(i))); - Node o2 = nm->mkNode(AND, msbn, ext.eqNode(bv::utils::mkOnes(i))); - Node o = nm->mkNode(OR, o1, o2); - - Node e = nm->mkNode(EQUAL, - bv::utils::mkExtract(t, w - 1 - i, 0), - bv::utils::mkExtract(s, w - 1, i)); - - nb << nm->mkNode(AND, e, o); - } - nb << nm->mkNode(AND, msbz, t.eqNode(z)); - nb << nm->mkNode(AND, msbn, t.eqNode(n)); - scl = nb.constructNode(); + * (or (= (bvashr s i) t) ...) + * for i in 0..w */ + scl = defaultShiftSc(EQUAL, BITVECTOR_ASHR, s, t); } else { @@ -1998,28 +1929,6 @@ static Node getScBvAshr(bool pol, return sc; } -static Node naiveShlSc1(Kind litk, Node s, Node t) -{ - unsigned w; - NodeBuilder<> nb(OR); - NodeManager *nm; - - nm = NodeManager::currentNM(); - - w = bv::utils::getSize(s); - Assert(w == bv::utils::getSize(t)); - - nb << nm->mkNode(litk, s, t); - for (unsigned i = 1; i <= w; i++) - { - Node sw = bv::utils::mkConst(w, i); - nb << nm->mkNode(litk, nm->mkNode(BITVECTOR_SHL, s, sw), t); - } - if (nb.getNumChildren() == 1) - return nb[0]; - return nb.constructNode(); -} - static Node getScBvShl(bool pol, Kind litk, Kind k, @@ -2047,30 +1956,11 @@ static Node getScBvShl(bool pol, if (pol) { /* x << s = t - * with side condition: - * (s = 0 || ctz(t) >= s) - * <-> - * (or - * (= s z) - * (and - * (bvult s w) - * (= ((_ extract w-1 0) (bvlshr (concat t z) (concat z s))) z)) - * (and (bvuge s w) (= t z))) - * - * where - * w = getSize(s) = getSize(t) = getSize(z) and z = 0 */ - Node shr = nm->mkNode(BITVECTOR_LSHR, - bv::utils::mkConcat(t, z), - bv::utils::mkConcat(z, s)); - Node ext = bv::utils::mkExtract(shr, w - 1, 0); - - Node o1 = nm->mkNode(EQUAL, s, z); - Node o2 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_ULT, s, ww), ext.eqNode(z)); - Node o3 = nm->mkNode(AND, - nm->mkNode(BITVECTOR_UGE, s, ww), t.eqNode(z)); - - scl = nm->mkNode(OR, o1, o2, o3); + * with side condition (synthesized): + * (= (bvshl (bvlshr t s) s) t) */ + Node lshr = nm->mkNode(BITVECTOR_LSHR, t, s); + Node shl = nm->mkNode(BITVECTOR_SHL, lshr, s); + scl = shl.eqNode(t); } else { @@ -2091,29 +1981,9 @@ static Node getScBvShl(bool pol, { /* s << x = t * with side condition: - * (or - * (= t z) - * (= s z) - * (and - * (= ((_ extract w-1 i) t) ((_ extract w-1-i 0))) - * (= ((_ extract i-1 0) t) z_i))) - * for 0 < i < w - * where - * w = getSize(s) = getSize(t) - * and z = 0 with getSize(z) = w - * and z_i = 0 with getSize(z_i) = i */ - NodeBuilder<> nb(nm, OR); - nb << nm->mkNode(EQUAL, t, s); - for (unsigned i = 1; i < w; ++i) - { - nb << nm->mkNode(AND, - nm->mkNode(EQUAL, - bv::utils::mkExtract(t, w-1, i), bv::utils::mkExtract(s, w-1-i, 0)), - nm->mkNode(EQUAL, - bv::utils::mkExtract(t, i-1, 0), bv::utils::mkZero(i))); - } - nb << t.eqNode(z); - scl = nb.constructNode(); + * (or (= (bvshl s i) t) ...) + * for i in 0..w */ + scl = defaultShiftSc(EQUAL, BITVECTOR_SHL, s, t); } else { @@ -2160,8 +2030,8 @@ static Node getScBvShl(bool pol, /* s << x >= t * with side condition: * (or (bvuge (bvshl s i) t) ...) - * for i in 0..w-1 */ - scl = naiveShlSc1(BITVECTOR_UGE, s, t); + * for i in 0..w */ + scl = defaultShiftSc(BITVECTOR_UGE, BITVECTOR_SHL, s, t); } } } @@ -2194,8 +2064,8 @@ static Node getScBvShl(bool pol, /* s << x > t * with side condition: * (or (bvugt (bvshl s i) t) ...) - * for i in 0..w-1 */ - scl = naiveShlSc1(BITVECTOR_UGT, s, t); + * for i in 0..w */ + scl = defaultShiftSc(BITVECTOR_UGT, BITVECTOR_SHL, s, t); } else { @@ -2257,8 +2127,8 @@ static Node getScBvShl(bool pol, /* s << x >= t * with side condition: * (or (bvsge (bvshl s i) t) ...) - * for i in 0..w-1 */ - scl = naiveShlSc1(BITVECTOR_SGE, s, t); + * for i in 0..w */ + scl = defaultShiftSc(BITVECTOR_SGE, BITVECTOR_SHL, s, t); } } } @@ -2300,8 +2170,8 @@ static Node getScBvShl(bool pol, /* s << x > t * with side condition: * (or (bvsgt (bvshl s i) t) ...) - * for i in 0..w-1 */ - scl = naiveShlSc1(BITVECTOR_SGT, s, t); + * for i in 0..w */ + scl = defaultShiftSc(BITVECTOR_SGT, BITVECTOR_SHL, s, t); } else {