Changing the handled operators in bv2int preprocessing pass (#4970)
authoryoni206 <yoni206@users.noreply.github.com>
Thu, 3 Sep 2020 20:28:48 +0000 (13:28 -0700)
committerGitHub <noreply@github.com>
Thu, 3 Sep 2020 20:28:48 +0000 (15:28 -0500)
Some of the bit-vector operators are directly translated to integers, while others are eliminated before the translation.
This PR changes the set of operators that we eliminate (and as a consequence, also the set of operators that we handle directly):

The only bit-wise operator that is translated is bvand. The rest are now eliminated.
bvneg is now eliminated.
The various division operators are still eliminated, but using different rewrite rules.
zero-extend and sign-extend are now handled directly.
shifting is changed to favor ITEs over non-linear multiplication.

src/preprocessing/passes/bv_to_int.cpp
src/preprocessing/passes/bv_to_int.h
test/regress/regress2/bv_to_int_shifts.smt2

index 04131211361527c57029a87ac1aeb24235311c69..c0e22a0eea7ccdcfd9e8186ebb66d0e260c4a337 100644 (file)
@@ -47,16 +47,6 @@ Rational intpow2(uint64_t b)
  */
 bool oneBitAnd(bool a, bool b) { return (a && b); }
 
-bool oneBitOr(bool a, bool b) { return (a || b); }
-
-bool oneBitXor(bool a, bool b) { return a != b; }
-
-bool oneBitXnor(bool a, bool b) { return a == b; }
-
-bool oneBitNand(bool a, bool b) { return !(a && b); }
-
-bool oneBitNor(bool a, bool b) { return !(a || b); }
-
 } //end empty namespace
 
 Node BVToInt::mkRangeConstraint(Node newVar, uint64_t k)
@@ -190,21 +180,24 @@ Node BVToInt::eliminationPass(Node n)
       // eliminate operators from it
       Node currentEliminated =
           FixpointRewriteStrategy<RewriteRule<UdivZero>,
-                                  RewriteRule<SdivEliminate>,
-                                  RewriteRule<SremEliminate>,
-                                  RewriteRule<SmodEliminate>,
+                                  RewriteRule<SdivEliminateFewerBitwiseOps>,
+                                  RewriteRule<SremEliminateFewerBitwiseOps>,
+                                  RewriteRule<SmodEliminateFewerBitwiseOps>,
+                                  RewriteRule<XnorEliminate>,
+                                  RewriteRule<NandEliminate>,
+                                  RewriteRule<NorEliminate>,
+                                  RewriteRule<NegEliminate>,
+                                  RewriteRule<XorEliminate>,
+                                  RewriteRule<OrEliminate>,
+                                  RewriteRule<SubEliminate>,
                                   RewriteRule<RepeatEliminate>,
-                                  RewriteRule<ZeroExtendEliminate>,
-                                  RewriteRule<SignExtendEliminate>,
                                   RewriteRule<RotateRightEliminate>,
                                   RewriteRule<RotateLeftEliminate>,
                                   RewriteRule<CompEliminate>,
                                   RewriteRule<SleEliminate>,
                                   RewriteRule<SltEliminate>,
                                   RewriteRule<SgtEliminate>,
-                                  RewriteRule<SgeEliminate>,
-                                  RewriteRule<ShlByConst>,
-                                  RewriteRule<LshrByConst> >::apply(current);
+                                  RewriteRule<SgeEliminate>>::apply(current);
       // save in the cache
       d_eliminationCache[current] = currentEliminated;
       // also assign the eliminated now to itself to avoid revisiting.
@@ -451,23 +444,6 @@ Node BVToInt::bvToInt(Node n)
               d_bvToIntCache[current] = ite;
               break;
             }
-            case kind::BITVECTOR_NEG:
-            {
-              // (bvneg x) is 2^k-x, unless x is 0, 
-              // in which case the result should be 0.
-              // This can be expressed by (2^k-x) mod 2^k
-              // However, since mod is an expensive arithmetic operation,
-              // we represent `bvneg` using an ITE.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node pow2BvSize = pow2(bvsize);
-              Node neg =
-                  d_nm->mkNode(kind::MINUS, pow2BvSize, translated_children[0]);
-              Node isZero =
-                  d_nm->mkNode(kind::EQUAL, translated_children[0], d_zero);
-              d_bvToIntCache[current] =
-                  d_nm->mkNode(kind::ITE, isZero, d_zero, neg);
-              break;
-            }
             case kind::BITVECTOR_NOT:
             {
               uint64_t bvsize = current[0].getType().getBitVectorSize();
@@ -496,66 +472,6 @@ Node BVToInt::bvToInt(Node n)
               d_bvToIntCache[current] = newNode;
               break;
             }
-            case kind::BITVECTOR_OR:
-            {
-              // Construct an ite, based on granularity.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node newNode = createBitwiseNode(translated_children[0],
-                                               translated_children[1],
-                                               bvsize,
-                                               granularity,
-                                               &oneBitOr);
-              d_bvToIntCache[current] = newNode;
-              break;
-            }
-            case kind::BITVECTOR_XOR:
-            {
-              // Construct an ite, based on granularity.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node newNode = createBitwiseNode(translated_children[0],
-                                               translated_children[1],
-                                               bvsize,
-                                               granularity,
-                                               &oneBitXor);
-              d_bvToIntCache[current] = newNode;
-              break;
-            }
-            case kind::BITVECTOR_XNOR:
-            {
-              // Construct an ite, based on granularity.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node newNode = createBitwiseNode(translated_children[0],
-                                               translated_children[1],
-                                               bvsize,
-                                               granularity,
-                                               &oneBitXnor);
-              d_bvToIntCache[current] = newNode;
-              break;
-            }
-            case kind::BITVECTOR_NAND:
-            {
-              // Construct an ite, based on granularity.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node newNode = createBitwiseNode(translated_children[0],
-                                               translated_children[1],
-                                               bvsize,
-                                               granularity,
-                                               &oneBitNand);
-              d_bvToIntCache[current] = newNode;
-              break;
-            }
-            case kind::BITVECTOR_NOR:
-            {
-              // Construct an ite, based on granularity.
-              uint64_t bvsize = current[0].getType().getBitVectorSize();
-              Node newNode = createBitwiseNode(translated_children[0],
-                                               translated_children[1],
-                                               bvsize,
-                                               granularity,
-                                               &oneBitNor);
-              d_bvToIntCache[current] = newNode;
-              break;
-            }
             case kind::BITVECTOR_SHL:
             {
               /**
@@ -621,6 +537,67 @@ Node BVToInt::bvToInt(Node n)
               d_bvToIntCache[current] = ite;
               break;
             }
+            case kind::BITVECTOR_ZERO_EXTEND:
+            {
+              d_bvToIntCache[current] = translated_children[0];
+              break;
+            }
+            case kind::BITVECTOR_SIGN_EXTEND:
+            {
+              uint64_t bvsize = current[0].getType().getBitVectorSize();
+              Node arg = translated_children[0];
+              if (arg.isConst())
+              {
+                Rational c(arg.getConst<Rational>());
+                Rational twoToKMinusOne(intpow2(bvsize - 1));
+                uint64_t amount = bv::utils::getSignExtendAmount(current);
+                /* if the msb is 0, this is like zero_extend.
+                 *  msb is 0 <-> the value is less than 2^{bvsize-1}
+                 */
+                if (c < twoToKMinusOne || amount == 0)
+                {
+                  d_bvToIntCache[current] = arg;
+                }
+                else
+                {
+                  /* otherwise, we add the integer equivalent of
+                   * 11....1 `amount` times
+                   */
+                  Rational max_of_amount = intpow2(amount) - 1;
+                  Rational mul = max_of_amount * intpow2(bvsize);
+                  Rational sum = mul + c;
+                  Node result = d_nm->mkConst(sum);
+                  d_bvToIntCache[current] = result;
+                }
+              }
+              else
+              {
+                uint64_t amount = bv::utils::getSignExtendAmount(current);
+                if (amount == 0)
+                {
+                  d_bvToIntCache[current] = translated_children[0];
+                }
+                else
+                {
+                  Rational twoToKMinusOne(intpow2(bvsize - 1));
+                  Node minSigned = d_nm->mkConst(twoToKMinusOne);
+                  /* condition checks whether the msb is 1.
+                   * This holds when the integer value is smaller than
+                   * 100...0, which is 2^{bvsize-1}.
+                   */
+                  Node condition = d_nm->mkNode(kind::LT, arg, minSigned);
+                  Node thenResult = arg;
+                  Node left = maxInt(amount);
+                  Node mul = d_nm->mkNode(kind::MULT, left, pow2(bvsize));
+                  Node sum = d_nm->mkNode(kind::PLUS, mul, arg);
+                  Node elseResult = sum;
+                  Node ite = d_nm->mkNode(
+                      kind::ITE, condition, thenResult, elseResult);
+                  d_bvToIntCache[current] = ite;
+                }
+              }
+              break;
+            }
             case kind::BITVECTOR_CONCAT:
             {
               // (concat a b) translates to a*2^k+b, k being the bitwidth of b.
@@ -882,31 +859,38 @@ Node BVToInt::createShiftNode(vector<Node> children,
                               uint64_t bvsize,
                               bool isLeftShift)
 {
+  /**
+   * from SMT-LIB:
+   * [[(bvshl s t)]] := nat2bv[m](bv2nat([[s]]) * 2^(bv2nat([[t]])))
+   * [[(bvlshr s t)]] := nat2bv[m](bv2nat([[s]]) div 2^(bv2nat([[t]])))
+   * Since we don't have exponentiation, we use an ite.
+   * Important note: below we use INTS_DIVISION_TOTAL, which is safe here
+   * because we divide by 2^... which is never 0.
+   */
   Node x = children[0];
   Node y = children[1];
+  // shifting by const is eliminated by the theory rewriter
   Assert(!y.isConst());
-  // ite represents 2^x for every integer x from 0 to bvsize-1.
-  Node ite = pow2(0);
-  for (uint64_t i = 1; i < bvsize; i++)
+  Node ite = d_zero;
+  Node body;
+  for (uint64_t i = 0; i < bvsize; i++)
   {
+    if (isLeftShift)
+    {
+      body = d_nm->mkNode(kind::INTS_MODULUS_TOTAL,
+                          d_nm->mkNode(kind::MULT, x, pow2(i)),
+                          pow2(bvsize));
+    }
+    else
+    {
+      body = d_nm->mkNode(kind::INTS_DIVISION_TOTAL, x, pow2(i));
+    }
     ite = d_nm->mkNode(kind::ITE,
                        d_nm->mkNode(kind::EQUAL, y, d_nm->mkConst<Rational>(i)),
-                       pow2(i),
+                       body,
                        ite);
   }
-  /**
-   * from SMT-LIB:
-   * [[(bvshl s t)]] := nat2bv[m](bv2nat([[s]]) * 2^(bv2nat([[t]])))
-   * [[(bvlshr s t)]] := nat2bv[m](bv2nat([[s]]) div 2^(bv2nat([[t]])))
-   * Since we don't have exponentiation, we use the ite declared above.
-   */
-  kind::Kind_t then_kind = isLeftShift ? kind::MULT : kind::INTS_DIVISION_TOTAL;
-  return d_nm->mkNode(kind::ITE,
-                              d_nm->mkNode(kind::LT, y, d_nm->mkConst<Rational>(bvsize)),
-                              d_nm->mkNode(kind::INTS_MODULUS_TOTAL,
-                                                            d_nm->mkNode(then_kind, x, ite),
-                                                            pow2(bvsize)),
-                              d_zero);
+  return ite;
 }
 
 Node BVToInt::createITEFromTable(
index d8ee698792f0a4ce91e82be4cdd34bf733436f40..2777a36a6ef61ae64dadacc8fc57583266c9045d 100644 (file)
  ** Tr(x) = fresh_x for every bit-vector variable x, where fresh_x is a fresh
  **         integer variable.
  ** Tr(c) = the integer value of c, for any bit-vector constant c.
- ** Tr((bvadd s t)) = Tr(s) + Tr(t) mod 2^k, where k is the bit width of 
+ ** Tr((bvadd s t)) = Tr(s) + Tr(t) mod 2^k, where k is the bit width of
  **         s and t.
  ** Similar transformations are done for bvmul, bvsub, bvudiv, bvurem, bvneg,
  **         bvnot, bvconcat, bvextract
+ ** Tr((_ zero_extend m) x) = Tr(x)
+ ** Tr((_ sign_extend m) x) = ite(msb(x)=0, x, 2^k*(2^m-1) + x))
+ ** explanation: if the msb is 0, this is the same as zero_extend,
+ ** which does not change the integer value.
+ ** If the msb is 1, then the result should correspond to
+ ** concat(1...1, x), with m 1's.
+ ** m 1's is 2^m-1, and multiplying it by x's width (k) moves it
+ ** to the front.
  **
  ** Tr((bvand s t)) depends on the granularity, which is provided by the user
  ** when enabling this preprocessing pass.
index 998234a17d357b9f278e1a3aac4ae6dca9560fa6..d213b0c3dc761c585efbf2ff4c5b708e8c374cb4 100644 (file)
@@ -1,17 +1,18 @@
 ; COMMAND-LINE: --solve-bv-as-int=sum --bvand-integer-granularity=1 --no-check-models  --no-check-unsat-cores
 ; EXPECT: sat
 (set-logic QF_BV)
-(declare-fun s () (_ BitVec 64))
-(declare-fun t () (_ BitVec 64))
-(declare-fun splust () (_ BitVec 64))
-(declare-fun shift1 () (_ BitVec 64))
-(declare-fun shift2 () (_ BitVec 64))
-(declare-fun negshift1 () (_ BitVec 64))
+(declare-fun s () (_ BitVec 4))
+(declare-fun t () (_ BitVec 4))
+(declare-fun splust () (_ BitVec 4))
+(declare-fun shift1 () (_ BitVec 4))
+(declare-fun shift2 () (_ BitVec 4))
+(declare-fun negshift1 () (_ BitVec 4))
 
 (assert (= shift1 (bvlshr s splust)))
 (assert (= shift2 (bvlshr t splust)))
 (assert (= negshift1 (bvneg shift1)))
 (assert (= splust (bvadd s t)))
 (assert (distinct negshift1 shift2))
+(assert (distinct s (bvshl s (_ bv4 4))))
 
 (check-sat)