Rewrites for BitVector multiplication (#1465)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 2 Jan 2018 22:12:45 +0000 (16:12 -0600)
committerGitHub <noreply@github.com>
Tue, 2 Jan 2018 22:12:45 +0000 (16:12 -0600)
src/theory/bv/theory_bv_rewrite_rules_normalization.h
src/theory/bv/theory_bv_rewrite_rules_simplification.h
src/theory/bv/theory_bv_utils.h
test/regress/regress0/bv/Makefile.am
test/regress/regress0/bv/mul-neg-unsat.smt2 [new file with mode: 0644]
test/regress/regress0/bv/mul-negpow2.smt2 [new file with mode: 0644]

index 61f072643bfdfd04e9146e8beb984d541298f0a6..3ad733f9964613964fbc3c44d78aa87bf60a9fa1 100644 (file)
@@ -381,25 +381,7 @@ Node RewriteRule<PlusCombineLikeTerms>::apply(TNode node) {
 
 template<> inline
 bool RewriteRule<MultSimplify>::applies(TNode node) {
-  if (node.getKind() != kind::BITVECTOR_MULT) {
-    return false;
-  }
-  TNode::iterator child_it = node.begin();
-  TNode::iterator child_next = child_it + 1;
-  for(; child_next != node.end(); ++child_it, ++child_next) {
-    if ((*child_it).isConst() ||
-        !((*child_it) < (*child_next))) {
-      return true;
-    }
-  }
-  if ((*child_it).isConst()) {
-    BitVector bv = (*child_it).getConst<BitVector>();
-    if (bv == BitVector(utils::getSize(node), (unsigned) 0) ||
-        bv == BitVector(utils::getSize(node), (unsigned) 1)) {
-      return true;
-    }
-  }
-  return false;
+  return node.getKind() == kind::BITVECTOR_MULT;
 }
 
 template<> inline
@@ -408,31 +390,58 @@ Node RewriteRule<MultSimplify>::apply(TNode node) {
   unsigned size = utils::getSize(node); 
   BitVector constant(size, Integer(1));
 
-  std::vector<Node> children; 
-  for(unsigned i = 0; i < node.getNumChildren(); ++i) {
-    TNode current = node[i];
+  bool isNeg = false;
+  std::vector<Node> children;
+  for (const TNode& current : node)
+  {
     if (current.getKind() == kind::CONST_BITVECTOR) {
       BitVector value = current.getConst<BitVector>();
       constant = constant * value;
       if(constant == BitVector(size, (unsigned) 0)) {
         return utils::mkConst(size, 0); 
       }
+    }
+    else if (current.getKind() == kind::BITVECTOR_NEG)
+    {
+      isNeg = !isNeg;
+      children.push_back(current[0]);
     } else {
       children.push_back(current); 
     }
   }
+  BitVector oValue = BitVector(size, static_cast<unsigned>(1));
+  BitVector noValue = utils::mkBitVectorOnes(size);
+
+  if (children.empty())
+  {
+    Assert(!isNeg);
+    return utils::mkConst(constant);
+  }
 
   std::sort(children.begin(), children.end());
 
-  if(constant != BitVector(size, (unsigned)1)) {
-    children.push_back(utils::mkConst(constant)); 
+  if (constant == noValue)
+  {
+    isNeg = !isNeg;
   }
-  
-  if(children.size() == 0) {
-    return utils::mkConst(size, (unsigned)1); 
+  else if (constant != oValue)
+  {
+    if (isNeg)
+    {
+      isNeg = !isNeg;
+      constant = -constant;
+    }
+    children.push_back(utils::mkConst(constant));
   }
 
-  return utils::mkNode(kind::BITVECTOR_MULT, children); 
+  Node ret = utils::mkNode(kind::BITVECTOR_MULT, children);
+
+  // if negative, negate entire node
+  if (isNeg && size > 1)
+  {
+    ret = utils::mkNode(kind::BITVECTOR_NEG, ret);
+  }
+  return ret;
 }
 
 
index 98a31189074705a9484827225bb5331f664eac58..9d44d3be5ec2b18eb48533e9813ad0b36b6bf61c 100644 (file)
@@ -751,38 +751,56 @@ Node RewriteRule<NotUle>::apply(TNode node) {
  * (a * 2^k) ==> a[n-k-1:0] 0_k
  */
 
-template<> inline
-bool RewriteRule<MultPow2>::applies(TNode node) {
+template <>
+inline bool RewriteRule<MultPow2>::applies(TNode node)
+{
   if (node.getKind() != kind::BITVECTOR_MULT)
     return false;
 
-  for(unsigned i = 0; i < node.getNumChildren(); ++i) {
-    if (utils::isPow2Const(node[i])) {
+  for (const Node& cn : node)
+  {
+    bool cIsNeg = false;
+    if (utils::isPow2Const(cn, cIsNeg))
+    {
       return true; 
     }
   }
   return false; 
 }
 
-template<> inline
-Node RewriteRule<MultPow2>::apply(TNode node) {
+template <>
+inline Node RewriteRule<MultPow2>::apply(TNode node)
+{
   Debug("bv-rewrite") << "RewriteRule<MultPow2>(" << node << ")" << std::endl;
 
+  unsigned size = utils::getSize(node);
   std::vector<Node>  children;
-  unsigned exponent = 0; 
-  for(unsigned i = 0; i < node.getNumChildren(); ++i) {
-    unsigned exp = utils::isPow2Const(node[i]);
+  unsigned exponent = 0;
+  bool isNeg = false;
+  for (const Node& cn : node)
+  {
+    bool cIsNeg = false;
+    unsigned exp = utils::isPow2Const(cn, cIsNeg);
     if (exp) {
       exponent += exp - 1;
+      if (cIsNeg)
+      {
+        isNeg = !isNeg;
+      }
     }
     else {
-      children.push_back(node[i]); 
+      children.push_back(cn);
     }
   }
 
-  Node a = utils::mkNode(kind::BITVECTOR_MULT, children); 
+  Node a = utils::mkNode(kind::BITVECTOR_MULT, children);
 
-  Node extract = utils::mkExtract(a, utils::getSize(node) - exponent - 1, 0);
+  if (isNeg && size > 1)
+  {
+    a = utils::mkNode(kind::BITVECTOR_NEG, a);
+  }
+
+  Node extract = utils::mkExtract(a, size - exponent - 1, 0);
   Node zeros = utils::mkConst(exponent, 0);
   return utils::mkConcat(extract, zeros); 
 }
@@ -888,24 +906,43 @@ Node RewriteRule<NegIdemp>::apply(TNode node) {
  * (a udiv 2^k) ==> 0_k a[n-1: k]
  */
 
-template<> inline
-bool RewriteRule<UdivPow2>::applies(TNode node) {
-  return (node.getKind() == kind::BITVECTOR_UDIV_TOTAL &&
-          utils::isPow2Const(node[1]));
+template <>
+inline bool RewriteRule<UdivPow2>::applies(TNode node)
+{
+  bool isNeg = false;
+  if (node.getKind() == kind::BITVECTOR_UDIV_TOTAL
+      && utils::isPow2Const(node[1], isNeg))
+  {
+    return !isNeg;
+  }
+  return false;
 }
 
-template<> inline
-Node RewriteRule<UdivPow2>::apply(TNode node) {
+template <>
+inline Node RewriteRule<UdivPow2>::apply(TNode node)
+{
   Debug("bv-rewrite") << "RewriteRule<UdivPow2>(" << node << ")" << std::endl;
+  unsigned size = utils::getSize(node);
   Node a = node[0];
-  unsigned power = utils::isPow2Const(node[1]) -1;
-  if (power == 0) {
-    return a; 
+  bool isNeg = false;
+  unsigned power = utils::isPow2Const(node[1], isNeg) - 1;
+  Node ret;
+  if (power == 0)
+  {
+    ret = a;
   }
-  Node extract = utils::mkExtract(a, utils::getSize(node) - 1, power);
-  Node zeros = utils::mkConst(power, 0);
-  
-  return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); 
+  else
+  {
+    Node extract = utils::mkExtract(a, size - 1, power);
+    Node zeros = utils::mkConst(power, 0);
+
+    ret = utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract);
+  }
+  if (isNeg && size > 1)
+  {
+    ret = utils::mkNode(kind::BITVECTOR_NEG, ret);
+  }
+  return ret;
 }
 
 /**
@@ -950,23 +987,37 @@ inline Node RewriteRule<UdivOne>::apply(TNode node) {
  * (a urem 2^k) ==> 0_(n-k) a[k-1:0]
  */
 
-template<> inline
-bool RewriteRule<UremPow2>::applies(TNode node) {
-  return (node.getKind() == kind::BITVECTOR_UREM_TOTAL &&
-          utils::isPow2Const(node[1]));
+template <>
+inline bool RewriteRule<UremPow2>::applies(TNode node)
+{
+  bool isNeg;
+  if (node.getKind() == kind::BITVECTOR_UREM_TOTAL
+      && utils::isPow2Const(node[1], isNeg))
+  {
+    return !isNeg;
+  }
+  return false;
 }
 
-template<> inline
-Node RewriteRule<UremPow2>::apply(TNode node) {
+template <>
+inline Node RewriteRule<UremPow2>::apply(TNode node)
+{
   Debug("bv-rewrite") << "RewriteRule<UremPow2>(" << node << ")" << std::endl;
   TNode a = node[0];
-  unsigned power = utils::isPow2Const(node[1]) - 1;
-  if (power == 0) {
-    return utils::mkConst(utils::getSize(node), 0);
+  bool isNeg = false;
+  unsigned power = utils::isPow2Const(node[1], isNeg) - 1;
+  Node ret;
+  if (power == 0)
+  {
+    ret = utils::mkZero(utils::getSize(node));
+  }
+  else
+  {
+    Node extract = utils::mkExtract(a, power - 1, 0);
+    Node zeros = utils::mkZero(utils::getSize(node) - power);
+    ret = utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract);
   }
-  Node extract = utils::mkExtract(a, power - 1, 0);
-  Node zeros = utils::mkConst(utils::getSize(node) - power, 0);
-  return utils::mkNode(kind::BITVECTOR_CONCAT, zeros, extract); 
+  return ret;
 }
 
 /**
index d9d1183af07792f63682e3bd07f8e5520abdf7df..ed772b7c4da4509a04d738b122db9ce394da3e6a 100644 (file)
@@ -272,13 +272,32 @@ inline Node mkConjunction(const std::set<TNode> nodes) {
   return conjunction;
 }
 
-inline unsigned isPow2Const(TNode node) {
+/**
+ * If node is a constant of the form 2^c or -2^c, then this function returns
+ * c+1. Otherwise, this function returns 0. The flag isNeg is updated to
+ * indicate whether node is negative.
+ */
+inline unsigned isPow2Const(TNode node, bool& isNeg)
+{
   if (node.getKind() != kind::CONST_BITVECTOR) {
     return false; 
   }
 
   BitVector bv = node.getConst<BitVector>();
-  return bv.isPow2(); 
+  unsigned p = bv.isPow2();
+  if (p != 0)
+  {
+    isNeg = false;
+    return p;
+  }
+  BitVector nbv = -bv;
+  p = nbv.isPow2();
+  if (p != 0)
+  {
+    isNeg = true;
+    return p;
+  }
+  return false;
 }
 
 inline Node mkOr(const std::vector<Node>& nodes) {
index 0ae0c69e020d19eb3113e65448c5916eda0a0933..68a5f791c390980e14729f4fbbfe275c467ba0b4 100644 (file)
@@ -103,7 +103,9 @@ SMT_TESTS = \
        bv-int-collapse2.smt2 \
        bv-int-collapse2-sat.smt2 \
        divtest_2_5.smt2 \
-       divtest_2_6.smt2
+       divtest_2_6.smt2 \
+       mul-neg-unsat.smt2 \
+       mul-negpow2.smt2
 
 # This benchmark is currently disabled as it uses --check-proof
 # bench_38.delta.smt2
diff --git a/test/regress/regress0/bv/mul-neg-unsat.smt2 b/test/regress/regress0/bv/mul-neg-unsat.smt2
new file mode 100644 (file)
index 0000000..751a8a3
--- /dev/null
@@ -0,0 +1,6 @@
+(set-logic QF_BV)
+(set-info :status unsat)
+(declare-fun a () (_ BitVec 32))
+(declare-fun b () (_ BitVec 32))
+(assert (not (= (bvmul a b) (bvmul (bvneg a) (bvneg b)))))
+(check-sat)
\ No newline at end of file
diff --git a/test/regress/regress0/bv/mul-negpow2.smt2 b/test/regress/regress0/bv/mul-negpow2.smt2
new file mode 100644 (file)
index 0000000..ace776e
--- /dev/null
@@ -0,0 +1,6 @@
+(set-logic QF_BV)
+(set-info :status unsat)
+(declare-fun a () (_ BitVec 32))
+(declare-fun b () (_ BitVec 32))
+(assert (not (= (bvmul a (_ bv4294967040 32)) (bvshl (bvneg a) (_ bv8 32)))))
+(check-sat)
\ No newline at end of file