Add rewriting rules for Eq/Ult with sign_extend and constants. (#1258)
authorMathias Preiner <mathias.preiner@gmail.com>
Sat, 21 Oct 2017 04:03:04 +0000 (21:03 -0700)
committerGitHub <noreply@github.com>
Sat, 21 Oct 2017 04:03:04 +0000 (21:03 -0700)
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv_rewrite_rules.h
src/theory/bv/theory_bv_rewrite_rules_simplification.h
src/theory/bv/theory_bv_rewriter.cpp

index 116903ea6216bbdc79d16e3d6a47df54a94ef348..e03cecdd929883e3889435444a494d886c0659f5 100644 (file)
@@ -798,6 +798,10 @@ Node TheoryBV::ppRewrite(TNode t)
     } else {
       res = t;
     }
+  } else if (RewriteRule<SignExtendEqConst>::applies(t)) {
+    res = RewriteRule<SignExtendEqConst>::run<false>(t);
+  } else if (RewriteRule<ZeroExtendEqConst>::applies(t)) {
+    res = RewriteRule<ZeroExtendEqConst>::run<false>(t);
   }
 
 
index 9f148d8233585f31abf40d21382e829edf008fb5..4d0f8033eea78b92fe2a8faf74e9ab21220c3a8e 100644 (file)
@@ -144,6 +144,10 @@ enum RewriteRuleId {
   SltZero,
   ZeroUlt,
   MergeSignExtend,
+  SignExtendEqConst,
+  ZeroExtendEqConst,
+  SignExtendUltConst,
+  ZeroExtendUltConst,
 
   /// normalization rules
   ExtractBitwise,
@@ -303,6 +307,10 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
   case SltZero : out << "SltZero"; return out;
   case ZeroUlt : out << "ZeroUlt"; return out;
   case MergeSignExtend : out << "MergeSignExtend"; return out;
+  case SignExtendEqConst: out << "SignExtendEqConst"; return out;
+  case ZeroExtendEqConst: out << "ZeroExtendEqConst"; return out;
+  case SignExtendUltConst: out << "SignExtendUltConst"; return out;
+  case ZeroExtendUltConst: out << "ZeroExtendUltConst"; return out;
     
   case UleEliminate : out << "UleEliminate"; return out;
   case BitwiseSlicing : out << "BitwiseSlicing"; return out;
@@ -533,6 +541,10 @@ struct AllRewriteRules {
   RewriteRule<IsPowerOfTwo> rule121;
   RewriteRule<RedorEliminate> rule122;
   RewriteRule<RedandEliminate> rule123;
+  RewriteRule<SignExtendEqConst> rule124;
+  RewriteRule<ZeroExtendEqConst> rule125;
+  RewriteRule<SignExtendUltConst> rule126;
+  RewriteRule<ZeroExtendUltConst> rule127;
 };
 
 template<> inline
index c7247f26055361d8627f8a0c171074048b12de85..24e5fb5e88e83a5be9b008e10f1c185759a98b1b 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "theory/bv/theory_bv_rewrite_rules.h"
 #include "theory/bv/theory_bv_utils.h"
+#include "theory/rewriter.h"
 
 namespace CVC4 {
 namespace theory {
@@ -1105,6 +1106,201 @@ Node RewriteRule<MergeSignExtend>::apply(TNode node) {
   return res;
 }
 
+/**
+ * ZeroExtendEqConst
+ *
+ * Rewrite zero_extend(x^n, m) = c^n+m to
+ *
+ *   false         if c[n+m-1:n] != 0
+ *   x = c[n-1:0]  otherwise.
+ */
+template <>
+inline bool RewriteRule<ZeroExtendEqConst>::applies(TNode node) {
+  return node.getKind() == kind::EQUAL &&
+         ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+           node[1].isConst()) ||
+          (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+           node[0].isConst()));
+}
+
+template <>
+inline Node RewriteRule<ZeroExtendEqConst>::apply(TNode node) {
+  TNode t, c;
+  if (node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND) {
+    t = node[0][0];
+    c = node[1];
+  } else {
+    t = node[1][0];
+    c = node[0];
+  }
+  BitVector c_hi =
+      c.getConst<BitVector>().extract(utils::getSize(c) - 1, utils::getSize(t));
+  BitVector c_lo = c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0);
+  BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+  if (c_hi == zero) {
+    return NodeManager::currentNM()->mkNode(kind::EQUAL, t,
+                                            utils::mkConst(c_lo));
+  }
+  return utils::mkFalse();
+}
+
+/**
+ * SignExtendEqConst
+ *
+ * Rewrite sign_extend(x^n, m) = c^n+m to
+ *
+ *   x = c[n-1:0]   if (c[n-1:n-1] == 0 && c[n+m-1:n] == 0) ||
+ *                     (c[n-1:n-1] == 1 && c[n+m-1:n] == ~0)
+ *   false          otherwise.
+ */
+template <>
+inline bool RewriteRule<SignExtendEqConst>::applies(TNode node) {
+  return node.getKind() == kind::EQUAL &&
+         ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+           node[1].isConst()) ||
+          (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+           node[0].isConst()));
+}
+
+template <>
+inline Node RewriteRule<SignExtendEqConst>::apply(TNode node) {
+  TNode t, c;
+  if (node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND) {
+    t = node[0][0];
+    c = node[1];
+  } else {
+    t = node[1][0];
+    c = node[0];
+  }
+  unsigned pos_msb_t = utils::getSize(t) - 1;
+  BitVector c_hi =
+      c.getConst<BitVector>().extract(utils::getSize(c) - 1, pos_msb_t);
+  BitVector c_lo = c.getConst<BitVector>().extract(pos_msb_t, 0);
+  BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+  if (c_hi == zero || c_hi == ~zero) {
+    return NodeManager::currentNM()->mkNode(kind::EQUAL, t,
+                                            utils::mkConst(c_lo));
+  }
+  return utils::mkFalse();
+}
+
+/**
+ * ZeroExtendUltConst
+ *
+ * Rewrite zero_extend(x^n,m) < c^n+m to
+ *
+ *   x < c[n-1:0]   if c[n+m-1:n] == 0.
+ *
+ * Rewrite c^n+m < Rewrite zero_extend(x^n,m) to
+ *
+ *   c[n-1:0] < x   if c[n+m-1:n] == 0.
+ */
+template <>
+inline bool RewriteRule<ZeroExtendUltConst>::applies(TNode node) {
+  if (node.getKind() == kind::BITVECTOR_ULT &&
+      ((node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+        node[1].isConst()) ||
+       (node[1].getKind() == kind::BITVECTOR_ZERO_EXTEND &&
+        node[0].isConst()))) {
+    TNode t, c;
+    bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND;
+    if (is_lhs) {
+      t = node[0][0];
+      c = node[1];
+    } else {
+      t = node[1][0];
+      c = node[0];
+    }
+    BitVector bv_c = c.getConst<BitVector>();
+    BitVector bv_max =
+        BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1);
+
+    BitVector c_hi = c.getConst<BitVector>().extract(utils::getSize(c) - 1,
+                                                     utils::getSize(t));
+    BitVector zero = BitVector(c_hi.getSize(), Integer(0));
+
+    return c_hi == zero;
+  }
+  return false;
+}
+
+template <>
+inline Node RewriteRule<ZeroExtendUltConst>::apply(TNode node) {
+  TNode t, c;
+  bool is_lhs = node[0].getKind() == kind::BITVECTOR_ZERO_EXTEND;
+  if (is_lhs) {
+    t = node[0][0];
+    c = node[1];
+  } else {
+    t = node[1][0];
+    c = node[0];
+  }
+  Node c_lo =
+      utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0));
+
+  if (is_lhs) {
+    return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo);
+  }
+  return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t);
+}
+
+/**
+ * SignExtendUltConst
+ *
+ * Rewrite sign_extend(x^n,m) < c^n+m to
+ *
+ *   x < c[n-1:0]   if c <= (1 << (n - 1)).
+ *
+ * Rewrite c^n+m < sign_extend(x^n,m) to
+ *
+ *   c[n-1:0] < x   if c < (1 << (n - 1)).
+ */
+template <>
+inline bool RewriteRule<SignExtendUltConst>::applies(TNode node) {
+  if (node.getKind() == kind::BITVECTOR_ULT &&
+      ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+        node[1].isConst()) ||
+       (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
+        node[0].isConst()))) {
+    TNode t, c;
+    bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
+    if (is_lhs) {
+      t = node[0][0];
+      c = node[1];
+    } else {
+      t = node[1][0];
+      c = node[0];
+    }
+    BitVector bv_c = c.getConst<BitVector>();
+    BitVector bv_max =
+        BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1);
+
+    return (is_lhs && bv_c <= bv_max) || (!is_lhs && bv_c < bv_max);
+  }
+  return false;
+}
+
+template <>
+inline Node RewriteRule<SignExtendUltConst>::apply(TNode node) {
+  TNode t, c;
+  bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
+  if (is_lhs) {
+    t = node[0][0];
+    c = node[1];
+  } else {
+    t = node[1][0];
+    c = node[0];
+  }
+  Node c_lo =
+      utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0));
+
+  if (is_lhs) {
+    return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo);
+  }
+  return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t);
+}
 
 template<> inline
 bool RewriteRule<MultSlice>::applies(TNode node) {
index f3185bc13f7d665f2ca15cd2b1224f6baeed01d5..bfaf517cc9733330c83ed5c6fae7a8b5f9f1a1fb 100644 (file)
@@ -68,10 +68,10 @@ RewriteResponse TheoryBVRewriter::postRewrite(TNode node) {
 RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) {
   // reduce common subexpressions on both sides
   Node resultNode = LinearRewriteStrategy
-    < RewriteRule<EvalUlt>,
-      // if both arguments are constants evaluates
-      RewriteRule<UltZero>
-      // a < 0 rewrites to false
+    < RewriteRule<EvalUlt>, // if both arguments are constants evaluates
+      RewriteRule<UltZero>, // a < 0 rewrites to false,
+      RewriteRule<SignExtendUltConst>,
+      RewriteRule<ZeroExtendUltConst>
        >::apply(node);
   
   return RewriteResponse(REWRITE_DONE, resultNode);