From: Mathias Preiner Date: Sat, 21 Oct 2017 04:03:04 +0000 (-0700) Subject: Add rewriting rules for Eq/Ult with sign_extend and constants. (#1258) X-Git-Tag: cvc5-1.0.0~5544 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=6b5c27d7f634eb5985ce455989fcda36e1261929;p=cvc5.git Add rewriting rules for Eq/Ult with sign_extend and constants. (#1258) --- diff --git a/src/theory/bv/theory_bv.cpp b/src/theory/bv/theory_bv.cpp index 116903ea6..e03cecdd9 100644 --- a/src/theory/bv/theory_bv.cpp +++ b/src/theory/bv/theory_bv.cpp @@ -798,6 +798,10 @@ Node TheoryBV::ppRewrite(TNode t) } else { res = t; } + } else if (RewriteRule::applies(t)) { + res = RewriteRule::run(t); + } else if (RewriteRule::applies(t)) { + res = RewriteRule::run(t); } diff --git a/src/theory/bv/theory_bv_rewrite_rules.h b/src/theory/bv/theory_bv_rewrite_rules.h index 9f148d823..4d0f8033e 100644 --- a/src/theory/bv/theory_bv_rewrite_rules.h +++ b/src/theory/bv/theory_bv_rewrite_rules.h @@ -144,6 +144,10 @@ enum RewriteRuleId { SltZero, ZeroUlt, MergeSignExtend, + SignExtendEqConst, + ZeroExtendEqConst, + SignExtendUltConst, + ZeroExtendUltConst, /// normalization rules ExtractBitwise, @@ -303,6 +307,10 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) { case SltZero : out << "SltZero"; return out; case ZeroUlt : out << "ZeroUlt"; return out; case MergeSignExtend : out << "MergeSignExtend"; return out; + case SignExtendEqConst: out << "SignExtendEqConst"; return out; + case ZeroExtendEqConst: out << "ZeroExtendEqConst"; return out; + case SignExtendUltConst: out << "SignExtendUltConst"; return out; + case ZeroExtendUltConst: out << "ZeroExtendUltConst"; return out; case UleEliminate : out << "UleEliminate"; return out; case BitwiseSlicing : out << "BitwiseSlicing"; return out; @@ -533,6 +541,10 @@ struct AllRewriteRules { RewriteRule rule121; RewriteRule rule122; RewriteRule rule123; + RewriteRule rule124; + RewriteRule rule125; + RewriteRule rule126; + RewriteRule rule127; }; template<> inline diff --git a/src/theory/bv/theory_bv_rewrite_rules_simplification.h b/src/theory/bv/theory_bv_rewrite_rules_simplification.h index c7247f260..24e5fb5e8 100644 --- a/src/theory/bv/theory_bv_rewrite_rules_simplification.h +++ b/src/theory/bv/theory_bv_rewrite_rules_simplification.h @@ -21,6 +21,7 @@ #include "theory/bv/theory_bv_rewrite_rules.h" #include "theory/bv/theory_bv_utils.h" +#include "theory/rewriter.h" namespace CVC4 { namespace theory { @@ -1105,6 +1106,201 @@ Node RewriteRule::apply(TNode node) { return res; } +/** + * ZeroExtendEqConst + * + * Rewrite zero_extend(x^n, m) = c^n+m to + * + * false if c[n+m-1:n] != 0 + * x = c[n-1:0] otherwise. + */ +template <> +inline bool RewriteRule::applies(TNode node) { + return node.getKind() == kind::EQUAL && + ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[0].isConst())); +} + +template <> +inline Node RewriteRule::apply(TNode node) { + TNode t, c; + if (node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector c_hi = + c.getConst().extract(utils::getSize(c) - 1, utils::getSize(t)); + BitVector c_lo = c.getConst().extract(utils::getSize(t) - 1, 0); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + if (c_hi == zero) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, t, + utils::mkConst(c_lo)); + } + return utils::mkFalse(); +} + +/** + * SignExtendEqConst + * + * Rewrite sign_extend(x^n, m) = c^n+m to + * + * x = c[n-1:0] if (c[n-1:n-1] == 0 && c[n+m-1:n] == 0) || + * (c[n-1:n-1] == 1 && c[n+m-1:n] == ~0) + * false otherwise. + */ +template <> +inline bool RewriteRule::applies(TNode node) { + return node.getKind() == kind::EQUAL && + ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[0].isConst())); +} + +template <> +inline Node RewriteRule::apply(TNode node) { + TNode t, c; + if (node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + unsigned pos_msb_t = utils::getSize(t) - 1; + BitVector c_hi = + c.getConst().extract(utils::getSize(c) - 1, pos_msb_t); + BitVector c_lo = c.getConst().extract(pos_msb_t, 0); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + if (c_hi == zero || c_hi == ~zero) { + return NodeManager::currentNM()->mkNode(kind::EQUAL, t, + utils::mkConst(c_lo)); + } + return utils::mkFalse(); +} + +/** + * ZeroExtendUltConst + * + * Rewrite zero_extend(x^n,m) < c^n+m to + * + * x < c[n-1:0] if c[n+m-1:n] == 0. + * + * Rewrite c^n+m < Rewrite zero_extend(x^n,m) to + * + * c[n-1:0] < x if c[n+m-1:n] == 0. + */ +template <> +inline bool RewriteRule::applies(TNode node) { + if (node.getKind() == kind::BITVECTOR_ULT && + ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND && + node[0].isConst()))) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector bv_c = c.getConst(); + BitVector bv_max = + BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1); + + BitVector c_hi = c.getConst().extract(utils::getSize(c) - 1, + utils::getSize(t)); + BitVector zero = BitVector(c_hi.getSize(), Integer(0)); + + return c_hi == zero; + } + return false; +} + +template <> +inline Node RewriteRule::apply(TNode node) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + Node c_lo = + utils::mkConst(c.getConst().extract(utils::getSize(t) - 1, 0)); + + if (is_lhs) { + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo); + } + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t); +} + +/** + * SignExtendUltConst + * + * Rewrite sign_extend(x^n,m) < c^n+m to + * + * x < c[n-1:0] if c <= (1 << (n - 1)). + * + * Rewrite c^n+m < sign_extend(x^n,m) to + * + * c[n-1:0] < x if c < (1 << (n - 1)). + */ +template <> +inline bool RewriteRule::applies(TNode node) { + if (node.getKind() == kind::BITVECTOR_ULT && + ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[1].isConst()) || + (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND && + node[0].isConst()))) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + BitVector bv_c = c.getConst(); + BitVector bv_max = + BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1); + + return (is_lhs && bv_c <= bv_max) || (!is_lhs && bv_c < bv_max); + } + return false; +} + +template <> +inline Node RewriteRule::apply(TNode node) { + TNode t, c; + bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND; + if (is_lhs) { + t = node[0][0]; + c = node[1]; + } else { + t = node[1][0]; + c = node[0]; + } + Node c_lo = + utils::mkConst(c.getConst().extract(utils::getSize(t) - 1, 0)); + + if (is_lhs) { + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo); + } + return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t); +} template<> inline bool RewriteRule::applies(TNode node) { diff --git a/src/theory/bv/theory_bv_rewriter.cpp b/src/theory/bv/theory_bv_rewriter.cpp index f3185bc13..bfaf517cc 100644 --- a/src/theory/bv/theory_bv_rewriter.cpp +++ b/src/theory/bv/theory_bv_rewriter.cpp @@ -68,10 +68,10 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) { RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) { // reduce common subexpressions on both sides Node resultNode = LinearRewriteStrategy - < RewriteRule, - // if both arguments are constants evaluates - RewriteRule - // a < 0 rewrites to false + < RewriteRule, // if both arguments are constants evaluates + RewriteRule, // a < 0 rewrites to false, + RewriteRule, + RewriteRule >::apply(node); return RewriteResponse(REWRITE_DONE, resultNode);