Add more general SignExtendUltConst rewriting. (#1385)
authorMathias Preiner <mathias.preiner@gmail.com>
Thu, 5 Apr 2018 18:40:34 +0000 (11:40 -0700)
committerGitHub <noreply@github.com>
Thu, 5 Apr 2018 18:40:34 +0000 (11:40 -0700)
This also adds an additional check in processAssertions to ensure that all
assertions are guaranteed to be rewritten (there was only a comment stating
this).

src/smt/smt_engine.cpp
src/theory/bv/theory_bv_rewrite_rules_simplification.h
src/theory/bv/theory_bv_rewriter.cpp

index abefaf215ec2bc4d7e60f0bdd010522e34f1f110..70dc046690aceb8a8182309e9dcdb574b60bed81 100644 (file)
@@ -4156,6 +4156,12 @@ void SmtEnginePrivate::processAssertions() {
   dumpAssertions("post-substitution", d_assertions);
 
   // Assertions ARE guaranteed to be rewritten by this point
+#ifdef CVC4_ASSERTIONS
+  for (unsigned i = 0; i < d_assertions.size(); ++i)
+  {
+    Assert(Rewriter::rewrite(d_assertions[i]) == d_assertions[i]);
+  }
+#endif
 
   // Lift bit-vectors of size 1 to bool
   if(options::bitvectorToBool()) {
index a4be19253f47f1c548024858dd2c58cf66e46c27..397385996a1d92d7d6eb9360266c97adf2ee84df 100644 (file)
@@ -1344,55 +1344,106 @@ inline Node RewriteRule<ZeroExtendUltConst>::apply(TNode node) {
  *
  * Rewrite sign_extend(x^n,m) < c^n+m to
  *
- *   x < c[n-1:0]   if c <= (1 << (n - 1)).
+ *   x < c[n-1:0]   if (c <= (1 << (n - 1))) || (c >= (~0 << (n - 1)))
+ *   x[n-1:n-1] = 0 if (1 << (n - 1)) < c <= (~0 << (n - 1)).
  *
  * Rewrite c^n+m < sign_extend(x^n,m) to
  *
- *   c[n-1:0] < x   if c < (1 << (n - 1)).
+ *   c[n-1:0] < x   if (c < (1 << (n - 1))) || (c >= ~(1 << (n-1)))
+ *   x[n-1:n-1] = 1 if ~(~0 << (n-1)) <= c <= ~(1 << (n-1))
+ *
+ * where ~(~0 << (n - 1)) == (1 << (n - 1)) - 1
+ *
  */
 template <>
-inline bool RewriteRule<SignExtendUltConst>::applies(TNode node) {
-  if (node.getKind() == kind::BITVECTOR_ULT &&
-      ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
-        node[1].isConst()) ||
-       (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND &&
-        node[0].isConst()))) {
-    TNode t, c;
+inline bool RewriteRule<SignExtendUltConst>::applies(TNode node)
+{
+  if (node.getKind() == kind::BITVECTOR_ULT
+      && ((node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND
+           && node[1].isConst())
+          || (node[1].getKind() == kind::BITVECTOR_SIGN_EXTEND
+              && node[0].isConst())))
+  {
+    TNode x, c;
     bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
-    if (is_lhs) {
-      t = node[0][0];
+    if (is_lhs)
+    {
+      x = node[0][0];
       c = node[1];
-    } else {
-      t = node[1][0];
+    }
+    else
+    {
+      x = node[1][0];
       c = node[0];
     }
     BitVector bv_c = c.getConst<BitVector>();
-    BitVector bv_max =
-        BitVector(utils::getSize(c)).setBit(utils::getSize(t) - 1);
-
-    return (is_lhs && bv_c <= bv_max) || (!is_lhs && bv_c < bv_max);
+    unsigned size_c = utils::getSize(c);
+    unsigned msb_x_pos = utils::getSize(x) - 1;
+    // (1 << (n - 1)))
+    BitVector bv_msb_x = BitVector(size_c).setBit(msb_x_pos);
+    // (~0 << (n - 1))
+    BitVector bv_upper_bits =
+        (~BitVector(size_c)).leftShift(BitVector(size_c, msb_x_pos));
+
+    return (is_lhs
+            && (bv_c <= bv_msb_x || bv_c >= bv_upper_bits
+                || (bv_msb_x < bv_c && bv_c <= bv_upper_bits)))
+           || (!is_lhs
+               && (bv_c < bv_msb_x || bv_c >= ~bv_msb_x
+                   || (~bv_upper_bits <= bv_c && bv_c <= ~bv_msb_x)));
   }
   return false;
 }
 
 template <>
-inline Node RewriteRule<SignExtendUltConst>::apply(TNode node) {
-  TNode t, c;
+inline Node RewriteRule<SignExtendUltConst>::apply(TNode node)
+{
+  TNode x, c;
   bool is_lhs = node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND;
-  if (is_lhs) {
-    t = node[0][0];
+  if (is_lhs)
+  {
+    x = node[0][0];
     c = node[1];
-  } else {
-    t = node[1][0];
+  }
+  else
+  {
+    x = node[1][0];
     c = node[0];
   }
-  Node c_lo =
-      utils::mkConst(c.getConst<BitVector>().extract(utils::getSize(t) - 1, 0));
+  BitVector bv_c = c.getConst<BitVector>();
 
-  if (is_lhs) {
-    return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, t, c_lo);
+  unsigned size_c = utils::getSize(c);
+  unsigned msb_x_pos = utils::getSize(x) - 1;
+  Node c_lo = utils::mkConst(bv_c.extract(msb_x_pos, 0));
+  // (1 << (n - 1)))
+  BitVector bv_msb_x = BitVector(size_c).setBit(msb_x_pos);
+  // (~0 << (n - 1))
+  BitVector bv_upper_bits =
+      (~BitVector(size_c)).leftShift(BitVector(size_c, msb_x_pos));
+
+  NodeManager* nm = NodeManager::currentNM();
+  if (is_lhs)
+  {
+    // x[n-1:n-1] = 0
+    if (bv_msb_x < bv_c && bv_c <= bv_upper_bits)
+    {
+      Node msb_x = utils::mkExtract(x, msb_x_pos, msb_x_pos);
+      return nm->mkNode(kind::EQUAL, msb_x, utils::mkZero(1));
+    }
+    // x < c[n-1:0]
+    Assert(bv_c <= bv_msb_x || bv_c >= bv_upper_bits);
+    return nm->mkNode(kind::BITVECTOR_ULT, x, c_lo);
   }
-  return NodeManager::currentNM()->mkNode(kind::BITVECTOR_ULT, c_lo, t);
+
+  // x[n-1:n-1] = 1
+  if (~bv_upper_bits <= bv_c && bv_c <= ~bv_msb_x)
+  {
+    Node msb_x = utils::mkExtract(x, msb_x_pos, msb_x_pos);
+    return nm->mkNode(kind::EQUAL, msb_x, utils::mkOne(1));
+  }
+  // c[n-1:0] < x
+  Assert(bv_c < bv_msb_x || bv_c >= ~bv_msb_x);
+  return nm->mkNode(kind::BITVECTOR_ULT, c_lo, x);
 }
 
 template<> inline
index 6bdd2ec28c161790aa1ed9c80a8f116998cc254a..c05e679476cc23801ebcb33e173300a1803e73f9 100644 (file)
@@ -73,8 +73,9 @@ RewriteResponse TheoryBVRewriter::RewriteUlt(TNode node, bool prerewrite) {
       RewriteRule<SignExtendUltConst>,
       RewriteRule<ZeroExtendUltConst>
        >::apply(node);
-  
-  return RewriteResponse(REWRITE_DONE, resultNode); 
+
+  return RewriteResponse(resultNode == node ? REWRITE_DONE : REWRITE_AGAIN_FULL,
+                         resultNode);
 }
 
 RewriteResponse TheoryBVRewriter::RewriteUltBv(TNode node, bool prerewrite) {