From 13cc0e94ac8892fa1cefa53ff1c884d154894b58 Mon Sep 17 00:00:00 2001 From: Mathias Preiner Date: Wed, 20 Dec 2017 16:45:07 -0800 Subject: [PATCH] Add rewriting rule for ranking benchmarks. (#1448) --- src/theory/bv/theory_bv_rewrite_rules.h | 5 +- .../theory_bv_rewrite_rules_simplification.h | 146 +++++++++++++++++- src/theory/bv/theory_bv_rewriter.cpp | 3 +- src/theory/bv/theory_bv_utils.h | 5 + 4 files changed, 156 insertions(+), 3 deletions(-) diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index 4d0f8033e..1d2e7d9a3 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -179,7 +179,8 @@ enum RewriteRuleId { BBPlusNeg, UltPlusOne, ConcatToMult, - IsPowerOfTwo + IsPowerOfTwo, + MultSltMult, }; inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { @@ -319,6 +320,7 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case UltPlusOne: out << "UltPlusOne"; return out; case ConcatToMult: out << "ConcatToMult"; return out; case IsPowerOfTwo: out << "IsPowerOfTwo"; return out; + case MultSltMult: out << "MultSltMult"; return out; default: Unreachable(); } @@ -545,6 +547,7 @@ struct AllRewriteRules { RewriteRule rule125; RewriteRule rule126; RewriteRule rule127; + RewriteRule rule128; }; template<> inline diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index 24e5fb5e8..98a311890 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -1431,9 +1431,153 @@ Node RewriteRule::apply(TNode node) { return x_eq_sh; } +/** + * Rewrite + * sign_extend(x+t,n) * sign_extend(a,m) < sign_extend(x,n) * sign_extend(a,m) + * to + * (and + * (not (= t zero)) + * (not (= a zero)) + * (= (bvslt (bvadd x t) x) (bvsgt a zero)) + * ) + * + * Rewrite + * zero_extend(x+t,n) * sign_extend(a,m) < zero_extend(x,n) * sign_extend(a,m) + * to + * (and + * (not (= t zero)) + * (not (= a zero)) + * (= (bvult (bvadd x t) x) (bvsgt a zero)) + * ) + * where n and m are sufficiently big to not produce an overflow for + * the multiplications. + * + * These patterns occur in the quantified BV benchmark family 'ranking', + * where the BV engine struggles due to the high bit widths of the + * multiplication's operands. + */ +static std::tuple +extract_ext_tuple(TNode node) +{ + TNode a = node[0]; + TNode b = node[1]; + for (unsigned i = 0; i < 2; ++i) + { + if (a.getKind() == kind::BITVECTOR_CONCAT + && b.getKind() == kind::BITVECTOR_SIGN_EXTEND + && a[0] == utils::mkZero(utils::getSize(a[0])) + && utils::getSize(a[1]) <= utils::getSize(a[0]) + && utils::getSize(b[0]) <= utils::getSignExtendAmount(b)) + { + return std::make_tuple(a[1], b[0], false); + } + else if (i == 0 + && a.getKind() == kind::BITVECTOR_SIGN_EXTEND + && b.getKind() == kind::BITVECTOR_SIGN_EXTEND + && utils::getSize(a[0]) <= utils::getSignExtendAmount(a) + && utils::getSize(b[0]) <= utils::getSignExtendAmount(b)) + { + return std::make_tuple(a[0], b[0], true); + } + std::swap(a, b); + } + return std::make_tuple(Node::null(), Node::null(), false); +} +template<> inline +bool RewriteRule::applies(TNode node) +{ + if (node.getKind() != kind::BITVECTOR_SLT + || node[0].getKind() != kind::BITVECTOR_MULT + || node[1].getKind() != kind::BITVECTOR_MULT) + return false; + if (node[0].getNumChildren() > 2 || node[1].getNumChildren() > 2) + return false; + bool is_sext_l, is_sext_r; + TNode ml[2], mr[2]; + + std::tie(ml[0], ml[1], is_sext_l) = extract_ext_tuple(node[0]); + if (ml[0].isNull()) + return false; + + std::tie(mr[0], mr[1], is_sext_r) = extract_ext_tuple(node[1]); + if (mr[0].isNull()) + return false; + + if (is_sext_l != is_sext_r) + return false; + + TNode addxt, x, a; + if (ml[0].getKind() == kind::BITVECTOR_PLUS) + { + addxt = ml[0]; + a = ml[1]; + } + else if (ml[1].getKind() == kind::BITVECTOR_PLUS) + { + addxt = ml[1]; + a = ml[0]; + } + else + return false; + + if (addxt.getNumChildren() > 2) + return false; + + if (mr[0] == a) + { + x = mr[1]; + } + else if (mr[1] == a) + { + x = mr[0]; + } + else + return false; + + return (addxt[0] == x || addxt[1] == x); } + +template<> inline +Node RewriteRule::apply(TNode node) +{ + bool is_sext; + TNode ml[2], mr[2]; + + std::tie(ml[0], ml[1], is_sext) = extract_ext_tuple(node[0]); + std::tie(mr[0], mr[1], std::ignore) = extract_ext_tuple(node[1]); + + TNode addxt, x, t, a; + if (ml[0].getKind() == kind::BITVECTOR_PLUS) + { + addxt = ml[0]; + a = ml[1]; + } + else + { + Assert(ml[1].getKind() == kind::BITVECTOR_PLUS); + addxt = ml[1]; + a = ml[0]; + } + + x = (mr[0] == a) ? mr[1] : mr[0]; + t = (addxt[0] == x) ? addxt[1] : addxt[0]; + + NodeManager *nm = NodeManager::currentNM(); + Node zero_t = utils::mkZero(utils::getSize(t)); + Node zero_a = utils::mkZero(utils::getSize(a)); + + NodeBuilder<> nb(kind::AND); + Kind k = is_sext ? kind::BITVECTOR_SLT : kind::BITVECTOR_ULT; + nb << t.eqNode(zero_t).notNode(); + nb << a.eqNode(zero_a).notNode(); + nb << nm->mkNode(k, addxt, x) + .eqNode(nm->mkNode(kind::BITVECTOR_SGT, a, zero_a)); + return nb.constructNode(); } -} + +} // namespace bv +} // namespace theory +} // namespace CVC4 diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index bfaf517cc..6bdd2ec28 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -89,7 +89,8 @@ RewriteResponse TheoryBVRewriter::RewriteUltBv(TNode node, bool prerewrite) { RewriteResponse TheoryBVRewriter::RewriteSlt(TNode node, bool prerewrite){ Node resultNode = LinearRewriteStrategy - < RewriteRule < EvalSlt > + < RewriteRule, + RewriteRule >::apply(node); return RewriteResponse(REWRITE_DONE, resultNode); diff --git a/src/theory/bv/theory_bv_utils.h b/src/theory/bv/theory_bv_utils.h index ea2dd4fc8..d9d1183af 100644 --- a/src/theory/bv/theory_bv_utils.h +++ b/src/theory/bv/theory_bv_utils.h @@ -54,6 +54,11 @@ inline unsigned getSize(TNode node) { return node.getType().getBitVectorSize(); } +inline unsigned getSignExtendAmount(TNode node) +{ + return node.getOperator().getConst().signExtendAmount; +} + inline const bool getBit(TNode node, unsigned i) { Assert (i < utils::getSize(node) && node.getKind() == kind::CONST_BITVECTOR); -- 2.30.2