From fc0a5dcc002b12f075681d53e87cca1ddfbd479d Mon Sep 17 00:00:00 2001 From: Aina Niemetz Date: Thu, 19 Oct 2017 12:31:42 -0700 Subject: [PATCH] CBQI BV: Refactor solve_bv_constraint. (#1265) This refactors function solve_bv_constraint to use a switch-case over kinds rather than an if-else chain. --- src/theory/quantifiers/bv_inverter.cpp | 438 +++++++++++++------------ 1 file changed, 232 insertions(+), 206 deletions(-) diff --git a/src/theory/quantifiers/bv_inverter.cpp b/src/theory/quantifiers/bv_inverter.cpp index 8a65338a6..ad1259be0 100644 --- a/src/theory/quantifiers/bv_inverter.cpp +++ b/src/theory/quantifiers/bv_inverter.cpp @@ -261,19 +261,25 @@ Node BvInverter::getPathToPv(Node lit, Node pv, Node sv, Node pvs, return slit; } -Node BvInverter::solve_bv_constraint(Node sv, Node sv_t, Node t, Kind rk, - bool pol, std::vector& path, +Node BvInverter::solve_bv_constraint(Node sv, + Node sv_t, + Node t, + Kind rk, + bool pol, + std::vector& path, BvInverterModelQuery* m, BvInverterStatus& status) { + unsigned index; + unsigned nchildren; NodeManager* nm = NodeManager::currentNM(); + while (!path.empty()) { - unsigned index = path.back(); + index = path.back(); Assert(index < sv_t.getNumChildren()); path.pop_back(); Kind k = sv_t.getKind(); - unsigned nchildren = sv_t.getNumChildren(); + nchildren = sv_t.getNumChildren(); - /* inversions */ if (k == BITVECTOR_CONCAT) { /* x = t[upper:lower] * where @@ -302,219 +308,239 @@ Node BvInverter::solve_bv_constraint(Node sv, Node sv_t, Node t, Kind rk, Node s = nchildren == 2 ? sv_t[1 - index] : dropChild(sv_t, index); /* Note: All n-ary kinds except for CONCAT (i.e., AND, OR, MULT, PLUS) * are commutative (no case split based on index). */ - if (k == BITVECTOR_PLUS) { - t = nm->mkNode(BITVECTOR_SUB, t, s); - } else if (k == BITVECTOR_SUB) { - t = nm->mkNode(BITVECTOR_PLUS, t, s); - } else if (k == BITVECTOR_MULT) { - /* t = skv (fresh skolem constant) - * with side condition: - * ctz(t) >= ctz(s) <-> x * s = t - * where - * ctz(t) >= ctz(s) -> (t & -t) >= (s & -s) */ - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - /* left hand side of side condition */ - Node scl = nm->mkNode( - BITVECTOR_UGE, - nm->mkNode(BITVECTOR_AND, t, nm->mkNode(BITVECTOR_NEG, t)), - nm->mkNode(BITVECTOR_AND, s, nm->mkNode(BITVECTOR_NEG, s))); - /* right hand side of side condition */ - Node scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_MULT, x, s), t); - /* overall side condition */ - Node sc = nm->mkNode(IMPLIES, scl, scr); - /* add side condition */ - status.d_conds.push_back(sc); - - /* get the skolem node for this side condition */ - Node skv = getInversionNode(sc, solve_tn); - /* now solving with the skolem node as the RHS */ - t = skv; - } else if (k == BITVECTOR_UREM_TOTAL) { - /* t = skv (fresh skolem constant) */ - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - Node scl, scr; - if (index == 0) { - /* x % s = t is rewritten to x - x / y * y */ - Trace("bv-invert") << "bv-invert : Unsupported for index " << index - << ", from " << sv_t << std::endl; - return Node::null(); - } else { - /* s % x = t - * with side conditions: - * s > t - * && s-t > t - * && (t = 0 || t != s-1) */ - Node s_gt_t = nm->mkNode(BITVECTOR_UGT, s, t); - Node s_m_t = nm->mkNode(BITVECTOR_SUB, s, t); - Node smt_gt_t = nm->mkNode(BITVECTOR_UGT, s_m_t, t); - Node t_eq_z = nm->mkNode(EQUAL, - t, bv::utils::mkZero(bv::utils::getSize(t))); - Node s_m_o = nm->mkNode(BITVECTOR_SUB, - s, bv::utils::mkOne(bv::utils::getSize(s))); - Node t_d_smo = nm->mkNode(DISTINCT, t, s_m_o); - - scl = nm->mkNode(AND, - nm->mkNode(AND, s_gt_t, smt_gt_t), - nm->mkNode(OR, t_eq_z, t_d_smo)); - scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UREM_TOTAL, s, x), t); + switch(k) { + case BITVECTOR_PLUS: + t = nm->mkNode(BITVECTOR_SUB, t, s); + break; + case BITVECTOR_SUB: + t = nm->mkNode(BITVECTOR_PLUS, t, s); + break; + + case BITVECTOR_MULT: { + /* t = skv (fresh skolem constant) + * with side condition: + * ctz(t) >= ctz(s) <-> x * s = t + * where + * ctz(t) >= ctz(s) -> (t & -t) >= (s & -s) */ + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + /* left hand side of side condition */ + Node scl = nm->mkNode( + BITVECTOR_UGE, + nm->mkNode(BITVECTOR_AND, t, nm->mkNode(BITVECTOR_NEG, t)), + nm->mkNode(BITVECTOR_AND, s, nm->mkNode(BITVECTOR_NEG, s))); + /* right hand side of side condition */ + Node scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_MULT, x, s), t); + /* overall side condition */ + Node sc = nm->mkNode(IMPLIES, scl, scr); + /* add side condition */ + status.d_conds.push_back(sc); + + /* get the skolem node for this side condition */ + Node skv = getInversionNode(sc, solve_tn); + /* now solving with the skolem node as the RHS */ + t = skv; + break; } - Node sc = nm->mkNode(IMPLIES, scl, scr); - status.d_conds.push_back(sc); - Node skv = getInversionNode(sc, solve_tn); - t = skv; - } else if (k == BITVECTOR_AND || k == BITVECTOR_OR) { - /* t = skv (fresh skolem constant) - * with side condition: - * t & s = t - * t | s = t */ - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - Node scl = nm->mkNode(EQUAL, t, nm->mkNode(k, t, s)); - Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); - Node sc = nm->mkNode(IMPLIES, scl, scr); - status.d_conds.push_back(sc); - Node skv = getInversionNode(sc, solve_tn); - t = skv; - } else if (k == BITVECTOR_LSHR) { - /* t = skv (fresh skolem constant) */ - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - Node scl, scr; - if (index == 0) { - /* x >> s = t + + case BITVECTOR_UREM_TOTAL: { + /* t = skv (fresh skolem constant) */ + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node scl, scr; + if (index == 0) { + /* x % s = t is rewritten to x - x / y * y */ + Trace("bv-invert") << "bv-invert : Unsupported for index " << index + << ", from " << sv_t << std::endl; + return Node::null(); + } else { + /* s % x = t + * with side conditions: + * s > t + * && s-t > t + * && (t = 0 || t != s-1) */ + Node s_gt_t = nm->mkNode(BITVECTOR_UGT, s, t); + Node s_m_t = nm->mkNode(BITVECTOR_SUB, s, t); + Node smt_gt_t = nm->mkNode(BITVECTOR_UGT, s_m_t, t); + Node t_eq_z = nm->mkNode(EQUAL, + t, bv::utils::mkZero(bv::utils::getSize(t))); + Node s_m_o = nm->mkNode(BITVECTOR_SUB, + s, bv::utils::mkOne(bv::utils::getSize(s))); + Node t_d_smo = nm->mkNode(DISTINCT, t, s_m_o); + + scl = nm->mkNode(AND, + nm->mkNode(AND, s_gt_t, smt_gt_t), + nm->mkNode(OR, t_eq_z, t_d_smo)); + scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UREM_TOTAL, s, x), t); + } + Node sc = nm->mkNode(IMPLIES, scl, scr); + status.d_conds.push_back(sc); + Node skv = getInversionNode(sc, solve_tn); + t = skv; + break; + } + + case BITVECTOR_AND: + case BITVECTOR_OR: { + /* t = skv (fresh skolem constant) * with side condition: - * s = 0 || clz(t) >= s - * -> - * s = 0 || ((z o t) << s)[2w-1 : w] = z - * with w = getSize(t) = getSize(s) and z = 0 with getSize(z) = w */ - unsigned w = bv::utils::getSize(s); - Node z = bv::utils::mkZero(w); - Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t); - Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s); - Node zot_shl_zos = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); - Node ext = bv::utils::mkExtract(zot_shl_zos, 2*w-1, w); - scl = nm->mkNode(OR, - nm->mkNode(EQUAL, s, z), - nm->mkNode(EQUAL, ext, z)); - scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_LSHR, x, s), t); + * t & s = t + * t | s = t */ + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node scl = nm->mkNode(EQUAL, t, nm->mkNode(k, t, s)); + Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t); Node sc = nm->mkNode(IMPLIES, scl, scr); status.d_conds.push_back(sc); Node skv = getInversionNode(sc, solve_tn); t = skv; - } else { - // TODO: index == 1 - /* s >> x = t - * with side conditions: - * (s = 0 && t = 0) - * || (clz(t) >= clz(s) - * && (t = 0 - * || "remaining shifted bits in t " - * "match corresponding bits in s")) */ - Trace("bv-invert") << "bv-invert : Unsupported for index " << index - << ", from " << sv_t << std::endl; - return Node::null(); + break; } - } else if (k == BITVECTOR_UDIV_TOTAL) { - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - Node s = sv_t[1 - index]; - unsigned w = bv::utils::getSize(s); - Node scl, scr; - Node zero = bv::utils::mkConst(w, 0u); - - /* x udiv s = t */ - if (index == 0) { - /* with side conditions: - * !umulo(s * t) - */ - scl = nm->mkNode(NOT, bv::utils::mkUmulo(s, t)); - scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, x, s), t); - /* s udiv x = t */ - } else { - /* with side conditions: - * (t = 0 && (s = 0 || s != 2^w-1)) - * || s >= t - * || t = 2^w-1 - */ - Node ones = bv::utils::mkOnes(w); - Node t_eq_zero = nm->mkNode(EQUAL, t, zero); - Node s_eq_zero = nm->mkNode(EQUAL, s, zero); - Node s_ne_ones = nm->mkNode(DISTINCT, s, ones); - Node s_ge_t = nm->mkNode(BITVECTOR_UGE, s, t); - Node t_eq_ones = nm->mkNode(EQUAL, t, ones); - scl = nm->mkNode( - OR, - nm->mkNode(AND, t_eq_zero, nm->mkNode(OR, s_eq_zero, s_ne_ones)), - s_ge_t, t_eq_ones); - scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, s, x), t); + + case BITVECTOR_LSHR: { + /* t = skv (fresh skolem constant) */ + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node scl, scr; + if (index == 0) { + /* x >> s = t + * with side condition: + * s = 0 || clz(t) >= s + * -> + * s = 0 || ((z o t) << s)[2w-1 : w] = z + * with w = getSize(t) = getSize(s) + * and z = 0 with getSize(z) = w */ + unsigned w = bv::utils::getSize(s); + Node z = bv::utils::mkZero(w); + Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t); + Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s); + Node zot_shl_zos = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s); + Node ext = bv::utils::mkExtract(zot_shl_zos, 2*w-1, w); + scl = nm->mkNode(OR, + nm->mkNode(EQUAL, s, z), + nm->mkNode(EQUAL, ext, z)); + scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_LSHR, x, s), t); + Node sc = nm->mkNode(IMPLIES, scl, scr); + status.d_conds.push_back(sc); + Node skv = getInversionNode(sc, solve_tn); + t = skv; + } else { + /* s >> x = t + * with side conditions: + * (s = 0 && t = 0) + * || (clz(t) >= clz(s) + * && (t = 0 + * || "remaining shifted bits in t " + * "match corresponding bits in s")) */ + Trace("bv-invert") << "bv-invert : Unsupported for index " << index + << ", from " << sv_t << std::endl; + return Node::null(); + } + break; } - /* overall side condition */ - Node sc = nm->mkNode(IMPLIES, scl, scr); - /* add side condition */ - status.d_conds.push_back(sc); - - /* get the skolem node for this side condition*/ - Node skv = getInversionNode(sc, solve_tn); - /* now solving with the skolem node as the RHS */ - t = skv; - } else if (k == BITVECTOR_SHL) { - TypeNode solve_tn = sv_t[index].getType(); - Node x = getSolveVariable(solve_tn); - Node s = sv_t[1 - index]; - unsigned w = bv::utils::getSize(s); - Node scl, scr; - - /* x << s = t */ - if (index == 0) { - /* with side conditions: - * (s = 0 || ctz(t) >= s) - * <-> - * (s = 0 || ((t o z) >> (z o s))[w-1:0] = z) - * - * where - * w = getSize(s) = getSize(t) = getSize (z) && z = 0 - */ + case BITVECTOR_UDIV_TOTAL: { + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node s = sv_t[1 - index]; + unsigned w = bv::utils::getSize(s); + Node scl, scr; Node zero = bv::utils::mkConst(w, 0u); - Node s_eq_zero = nm->mkNode(EQUAL, s, zero); - Node t_conc_zero = nm->mkNode(BITVECTOR_CONCAT, t, zero); - Node zero_conc_s = nm->mkNode(BITVECTOR_CONCAT, zero, s); - Node shr_s = nm->mkNode(BITVECTOR_LSHR, t_conc_zero, zero_conc_s); - Node extr_shr_s = bv::utils::mkExtract(shr_s, w - 1, 0); - Node ctz_t_ge_s = nm->mkNode(EQUAL, extr_shr_s, zero); - scl = nm->mkNode(OR, s_eq_zero, ctz_t_ge_s); - scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_SHL, x, s), t); - /* s << x = t */ - } else { - /* with side conditions: - * (s = 0 && t = 0) - * || (ctz(t) >= ctz(s) - * && (t = 0 || - * "remaining shifted bits in t match corresponding bits in s")) - */ - Trace("bv-invert") << "bv-invert : Unsupported for index " << index - << ", from " << sv_t << std::endl; - return Node::null(); + + if (index == 0) { + /* x udiv s = t + * with side conditions: + * !umulo(s * t) + */ + scl = nm->mkNode(NOT, bv::utils::mkUmulo(s, t)); + scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, x, s), t); + } else { + /* s udiv x = t + * with side conditions: + * (t = 0 && (s = 0 || s != 2^w-1)) + * || s >= t + * || t = 2^w-1 + */ + Node ones = bv::utils::mkOnes(w); + Node t_eq_zero = nm->mkNode(EQUAL, t, zero); + Node s_eq_zero = nm->mkNode(EQUAL, s, zero); + Node s_ne_ones = nm->mkNode(DISTINCT, s, ones); + Node s_ge_t = nm->mkNode(BITVECTOR_UGE, s, t); + Node t_eq_ones = nm->mkNode(EQUAL, t, ones); + scl = nm->mkNode(OR, + nm->mkNode(AND, t_eq_zero, + nm->mkNode(OR, s_eq_zero, s_ne_ones)), + s_ge_t, t_eq_ones); + scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, s, x), t); + } + + /* overall side condition */ + Node sc = nm->mkNode(IMPLIES, scl, scr); + /* add side condition */ + status.d_conds.push_back(sc); + + /* get the skolem node for this side condition*/ + Node skv = getInversionNode(sc, solve_tn); + /* now solving with the skolem node as the RHS */ + t = skv; + break; } - /* overall side condition */ - Node sc = nm->mkNode(IMPLIES, scl, scr); - /* add side condition */ - status.d_conds.push_back(sc); - - /* get the skolem node for this side condition*/ - Node skv = getInversionNode(sc, solve_tn); - /* now solving with the skolem node as the RHS */ - t = skv; - //}else if( k==BITVECTOR_ASHR ){ - // TODO - } else { - Trace("bv-invert") << "bv-invert : Unknown kind for bit-vector term " - << k - << ", from " << sv_t << std::endl; - return Node::null(); + case BITVECTOR_SHL: { + TypeNode solve_tn = sv_t[index].getType(); + Node x = getSolveVariable(solve_tn); + Node s = sv_t[1 - index]; + unsigned w = bv::utils::getSize(s); + Node scl, scr; + + if (index == 0) { + /* x << s = t + * with side conditions: + * (s = 0 || ctz(t) >= s) + * <-> + * (s = 0 || ((t o z) >> (z o s))[w-1:0] = z) + * + * where + * w = getSize(s) = getSize(t) = getSize (z) && z = 0 + */ + Node zero = bv::utils::mkConst(w, 0u); + Node s_eq_zero = nm->mkNode(EQUAL, s, zero); + Node t_conc_zero = nm->mkNode(BITVECTOR_CONCAT, t, zero); + Node zero_conc_s = nm->mkNode(BITVECTOR_CONCAT, zero, s); + Node shr_s = nm->mkNode(BITVECTOR_LSHR, t_conc_zero, zero_conc_s); + Node extr_shr_s = bv::utils::mkExtract(shr_s, w - 1, 0); + Node ctz_t_ge_s = nm->mkNode(EQUAL, extr_shr_s, zero); + scl = nm->mkNode(OR, s_eq_zero, ctz_t_ge_s); + scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_SHL, x, s), t); + } else { + /* s << x = t + * with side conditions: + * (s = 0 && t = 0) + * || (ctz(t) >= ctz(s) + * && (t = 0 || + * "remaining shifted bits in t" + * "match corresponding bits in s")) + */ + Trace("bv-invert") << "bv-invert : Unsupported for index " << index + << "for bit-vector term " << sv_t << std::endl; + return Node::null(); + } + + /* overall side condition */ + Node sc = nm->mkNode(IMPLIES, scl, scr); + /* add side condition */ + status.d_conds.push_back(sc); + + /* get the skolem node for this side condition*/ + Node skv = getInversionNode(sc, solve_tn); + /* now solving with the skolem node as the RHS */ + t = skv; + break; + } + default: + Trace("bv-invert") << "bv-invert : Unknown kind " << k + << " for bit-vector term " << sv_t << std::endl; + return Node::null(); } } sv_t = sv_t[index]; -- 2.30.2