added several rewrite rules (BitwiseSlicing, Ule/SleEliminate, ExtractSignExtend...
authorlianah <lianahady@gmail.com>
Tue, 30 Apr 2013 17:42:50 +0000 (13:42 -0400)
committerlianah <lianahady@gmail.com>
Tue, 30 Apr 2013 19:54:24 +0000 (15:54 -0400)
src/prop/bvminisat/bvminisat.cpp
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv_rewrite_rules.h
src/theory/bv/theory_bv_rewrite_rules_normalization.h
src/theory/bv/theory_bv_rewriter.cpp
src/theory/bv/theory_bv_utils.h

index ab25fa6cbec8ccfbc12c83a8fee52bddba0c7f68..fa5f53113021b23c2f6da76e6388d0ae91c09aa5 100644 (file)
@@ -101,6 +101,7 @@ void BVMinisatSatSolver::interrupt(){
 }
 
 SatValue BVMinisatSatSolver::solve(){
+  ++d_statistics.d_statCallsToSolve;
   return toSatLiteralValue(d_minisat->solve());
 }
 
@@ -121,30 +122,6 @@ SatValue BVMinisatSatSolver::solve(long unsigned int& resource){
   return result;
 }
 
-// SatValue BVMinisatSatSolver::solve(const context::CDList<SatLiteral> & assumptions, bool only_bcp){
-//   ++d_solveCount;
-//   ++d_statistics.d_statCallsToSolve;
-
-//   Debug("sat::minisat") << "Solve with assumptions ";
-//   context::CDList<SatLiteral>::const_iterator it = assumptions.begin();
-//   BVMinisat::vec<BVMinisat::Lit> assump;
-//   for(; it!= assumptions.end(); ++it) {
-//     SatLiteral lit = *it;
-//     Debug("sat::minisat") << lit <<" ";
-//     assump.push(toMinisatLit(lit));
-//   }
-//   Debug("sat::minisat") <<"\n";
-
-//   clock_t begin, end;
-//   begin = clock();
-//   d_minisat->setOnlyBCP(only_bcp); 
-//   SatLiteralValue result = toSatLiteralValue(d_minisat->solve(assump));
-//   end = clock();
-//   d_statistics.d_statSolveTime = d_statistics.d_statSolveTime.getData() + (end - begin)/(double)CLOCKS_PER_SEC; 
-//  return result;
-// }
-
-
 void BVMinisatSatSolver::getUnsatCore(SatClause& unsatCore) {
   // TODO add assertion to check the call was after an unsat call
   for (int i = 0; i < d_minisat->conflict.size(); ++i) {
index 953f9b3e5bfcb52756e781ef4f1b372cb054b557..b2f91e07054d19d93ef545240ac9ceb3cda9444d 100644 (file)
@@ -132,16 +132,22 @@ void TheoryBV::checkForLemma(TNode fact) {
       TNode urem = fact[0];
       TNode result = fact[1];
       TNode divisor = urem[1]; 
-      Node result_ult_div = utils::mkNode(kind::BITVECTOR_ULT, result, divisor);
-      Node split = utils::mkNode(kind::OR, utils::mkNode(kind::NOT, fact), result_ult_div);
+      Node result_ult_div = mkNode(kind::BITVECTOR_ULT, result, divisor);
+      Node divisor_eq_0 = mkNode(kind::EQUAL,
+                                 divisor,
+                                 mkConst(BitVector(getSize(divisor), 0u)));  
+      Node split = utils::mkNode(kind::OR, divisor_eq_0, mkNode(kind::NOT, fact), result_ult_div);
       lemma(split);
     }
     if (fact[1].getKind() == kind::BITVECTOR_UREM_TOTAL) {
       TNode urem = fact[1];
       TNode result = fact[0];
       TNode divisor = urem[1]; 
-      Node result_ult_div = utils::mkNode(kind::BITVECTOR_ULT, result, divisor);
-      Node split = utils::mkNode(kind::OR, utils::mkNode(kind::NOT, fact), result_ult_div);
+      Node result_ult_div = mkNode(kind::BITVECTOR_ULT, result, divisor);
+      Node divisor_eq_0 = mkNode(kind::EQUAL,
+                                  divisor,
+                                  mkConst(BitVector(getSize(divisor), 0u)));  
+      Node split = utils::mkNode(kind::OR, divisor_eq_0, mkNode(kind::NOT, fact), result_ult_div);
       lemma(split);
     }
   }
index d362fa603509072ecd94e8b77b2dcc4eb734dc20..baaf7e13369596e2b9525cd6b1e1c4e0297d9791 100644 (file)
@@ -52,6 +52,7 @@ enum RewriteRuleId {
   SubEliminate,
   SltEliminate,
   SleEliminate,
+  UleEliminate, 
   CompEliminate,
   RepeatEliminate,
   RotateLeftEliminate,
@@ -135,6 +136,7 @@ enum RewriteRuleId {
   ExtractNot,
   ExtractArith,
   ExtractArith2,
+  ExtractSignExtend,
   DoubleNeg,
   NegMult,
   NegSub,
@@ -152,7 +154,7 @@ enum RewriteRuleId {
   AndSimplify,
   OrSimplify,
   XorSimplify,
-  UleEliminate, 
+  BitwiseSlicing,
   // rules to simplify bitblasting
   BBPlusNeg
  };
@@ -270,7 +272,9 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
   case UltOne : out << "UltOne"; return out;
   case SltZero : out << "SltZero"; return out;
   case ZeroUlt : out << "ZeroUlt"; return out;
-  case UleEliminate : out << "UleEliminate"; return out; 
+  case UleEliminate : out << "UleEliminate"; return out;
+  case BitwiseSlicing : out << "BitwiseSlicing"; return out;
+  case ExtractSignExtend : out << "ExtractSignExtend"; return out; 
   default:
     Unreachable();
   }
index 4ba09ef67fc2f97f38ae615c4b075a2165c6cf7b..035bd4469bb0c68e2ace1de449db081711a88407 100644 (file)
@@ -72,6 +72,58 @@ Node RewriteRule<ExtractNot>::apply(TNode node) {
   return utils::mkNode(kind::BITVECTOR_NOT, a); 
 }
 
+/** 
+ * ExtractSignExtend
+ * 
+ * (sign_extend k x) [i:j] => pushes extract in
+ * 
+ * @return 
+ */
+
+template<> inline
+bool RewriteRule<ExtractSignExtend>::applies(TNode node) {
+  if (node.getKind() == kind::BITVECTOR_EXTRACT &&
+      node[0].getKind() == kind::BITVECTOR_SIGN_EXTEND) {
+    return true; 
+  }
+  return false; 
+}
+
+template<> inline
+Node RewriteRule<ExtractSignExtend>::apply(TNode node) {
+  Debug("bv-rewrite") << "RewriteRule<ExtractSignExtend>(" << node << ")" << std::endl;
+  TNode extendee = node[0][0]; 
+  unsigned extendee_size = utils::getSize(extendee);
+
+  unsigned high = utils::getExtractHigh(node);
+  unsigned low = utils::getExtractLow(node); 
+
+  Node resultNode; 
+  // extract falls on extendee
+  if (high < extendee_size) {
+    resultNode = utils::mkExtract(extendee, high, low); 
+  } else if (low < extendee_size && high >= extendee_size) {
+    // if extract overlaps sign extend and extendee
+    Node low_extract = utils::mkExtract(extendee, extendee_size - 1, low);
+    unsigned new_ammount = high - extendee_size + 1;
+    resultNode = utils::mkSignExtend(low_extract, new_ammount); 
+  } else {
+    // extract only over sign extend
+    Assert (low >= extendee_size);
+    unsigned top = utils::getSize(extendee) - 1; 
+    Node most_significant_bit = utils::mkExtract(extendee, top, top);
+    std::vector<Node> bits;
+    for (unsigned i = 0; i < high - low + 1; ++i) {
+      bits.push_back(most_significant_bit); 
+    }
+    resultNode =  utils::mkNode(kind::BITVECTOR_CONCAT, bits);
+  }
+  Debug("bv-rewrite") << "                           =>" << resultNode << std::endl;
+  return resultNode; 
+}
+
+
+
 /**
  * ExtractArith
  * 
@@ -1032,19 +1084,84 @@ Node RewriteRule<XorSimplify>::apply(TNode node) {
 }
 
 
+/** 
+ * BitwiseSlicing
+ * 
+ * (a bvand c) ==> (concat (bvand a[i0:j0] c0) ... (bvand a[in:jn] cn))
+ *  where c0,..., cn are maximally continuous substrings of 0 or 1 in the constant c 
+ *
+ * Note: this rule assumes AndSimplify has already been called on the node
+ */
+template<> inline
+bool RewriteRule<BitwiseSlicing>::applies(TNode node) {
+  if ((node.getKind() != kind::BITVECTOR_AND &&
+      node.getKind() != kind::BITVECTOR_OR &&
+      node.getKind() != kind::BITVECTOR_XOR) ||
+      utils::getSize(node) == 1)
+    return false; 
+  
+  for (unsigned i = 0; i < node.getNumChildren(); ++i) {
+    if (node[i].getKind() == kind::CONST_BITVECTOR) {
+      BitVector constant = node[i].getConst<BitVector>();
+      // we do not apply the rule if the constant is all 0s or all 1s
+      if (constant == BitVector(utils::getSize(node), 0u)) 
+        return false; 
+      
+      for (unsigned i = 0; i < constant.getSize(); ++i) {
+        if (!constant.isBitSet(i)) 
+          return true; 
+      }
+    }
+  }
+  return false; 
+}
 
+template<> inline
+Node RewriteRule<BitwiseSlicing>::apply(TNode node) {
+  Debug("bv-rewrite") << "RewriteRule<BitwiseSlicing>(" << node << ")" << std::endl;
+  // get the constant
+  bool found_constant = false;
+  TNode constant;
+  std::vector<Node> other_children; 
+  for (unsigned i = 0; i < node.getNumChildren(); ++i) {
+    if (node[i].getKind() == kind::CONST_BITVECTOR) {
+      constant = node[i];
+      Assert (!found_constant); 
+      found_constant = true; 
+    } else {
+      other_children.push_back(node[i]); 
+    }
+  }
+  Assert (found_constant && other_children.size() == node.getNumChildren() - 1);
 
-// template<> inline
-// bool RewriteRule<AndSimplify>::applies(TNode node) {
-//   return (node.getKind() == kind::BITVECTOR_AND);
-// }
-
-// template<> inline
-// Node RewriteRule<AndSimplify>::apply(TNode node) {
-//   Debug("bv-rewrite") << "RewriteRule<AndSimplify>(" << node << ")" << std::endl;
-//   return resultNode;
-// }
+  Node other = utils::mkNode(node.getKind(), other_children);
+  
+  BitVector bv_constant = constant.getConst<BitVector>();
+  std::vector<Node> concat_children; 
+  int start = bv_constant.getSize() - 1;
+  int end = start;
+  for (int i = end - 1; i >= 0; --i) {
+    if (bv_constant.isBitSet(i + 1) != bv_constant.isBitSet(i)) {
+      Node other_extract = utils::mkExtract(other, end, start);
+      Node const_extract = utils::mkExtract(constant, end, start);
+      Node bitwise_op = utils::mkNode(node.getKind(), const_extract, other_extract);
+      concat_children.push_back(bitwise_op);
+      start = end = i; 
+    } else {
+      start--; 
+    }
+    if (i == 0) {
+      Node other_extract = utils::mkExtract(other, end, 0);
+      Node const_extract = utils::mkExtract(constant, end, 0);
+      Node bitwise_op = utils::mkNode(node.getKind(), const_extract, other_extract);
+      concat_children.push_back(bitwise_op);
+    }
 
+  }
+  Node result = utils::mkNode(kind::BITVECTOR_CONCAT, concat_children);
+  Debug("bv-rewrite") << "    =>" << result << std::endl;
+  return result;
+}
 
 // template<> inline
 // bool RewriteRule<>::applies(TNode node) {
index f6d138f5d541b8d7a0367dd721503eaeb4c640ce..44c498947315a8554fa4c48acd493309e3282447 100644 (file)
@@ -177,6 +177,11 @@ RewriteResponse TheoryBVRewriter::RewriteExtract(TNode node, bool preregister) {
     return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
   }
 
+  // if (RewriteRule<ExtractSignExtend>::applies(node)) {
+  //   resultNode = RewriteRule<ExtractSignExtend>::run<false>(node);
+  //   return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
+  // }
+
   if (RewriteRule<ExtractBitwise>::applies(node)) {
     resultNode = RewriteRule<ExtractBitwise>::run<false>(node);
     return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
@@ -223,14 +228,14 @@ RewriteResponse TheoryBVRewriter::RewriteAnd(TNode node, bool preregister){
   
   resultNode = LinearRewriteStrategy
     < RewriteRule<FlattenAssocCommut>,
-      RewriteRule<AndSimplify>
-      // RewriteRule<EvalAnd>,
-      // RewriteRule<BitwiseIdemp>,
-      // //RewriteRule<BitwiseSlice>, -> might need rw again
-      // RewriteRule<AndZero>,
-      // RewriteRule<AndOne> 
+      RewriteRule<AndSimplify>//,
+      //      RewriteRule<BitwiseSlicing>
       >::apply(node);
 
+  if (resultNode.getKind() != node.getKind()) {
+    return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
+  }
+  
   return RewriteResponse(REWRITE_DONE, resultNode); 
 }
 
@@ -239,8 +244,13 @@ RewriteResponse TheoryBVRewriter::RewriteOr(TNode node, bool preregister){
 
   resultNode = LinearRewriteStrategy
     < RewriteRule<FlattenAssocCommut>,
-      RewriteRule<OrSimplify>
+      RewriteRule<OrSimplify>//,
+      //      RewriteRule<BitwiseSlicing>
     >::apply(node);
+
+  if (resultNode.getKind() != node.getKind()) {
+    return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
+  }
   
   return RewriteResponse(REWRITE_DONE, resultNode); 
 }
@@ -251,7 +261,8 @@ RewriteResponse TheoryBVRewriter::RewriteXor(TNode node, bool preregister) {
   resultNode = LinearRewriteStrategy
     < RewriteRule<FlattenAssocCommut>, // flatten the expression 
       RewriteRule<XorSimplify>,        // simplify duplicates and constants
-      RewriteRule<XorZero>             // checks if the constant part is zero and eliminates it
+      RewriteRule<XorZero>//,            // checks if the constant part is zero and eliminates it
+      //      RewriteRule<BitwiseSlicing>
     >::apply(node);
 
   // this simplification introduces new terms and might require further
@@ -261,6 +272,10 @@ RewriteResponse TheoryBVRewriter::RewriteXor(TNode node, bool preregister) {
     return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
   }
 
+  if (resultNode.getKind() != node.getKind()) {
+    return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
+  }
+
   return RewriteResponse(REWRITE_DONE, resultNode); 
 }
 
@@ -301,7 +316,8 @@ RewriteResponse TheoryBVRewriter::RewriteMult(TNode node, bool preregister) {
 
   resultNode = LinearRewriteStrategy
     < RewriteRule<FlattenAssocCommut>, // flattens and sorts
-      RewriteRule<MultSimplify>        // multiplies constant part and checks for 0
+      RewriteRule<MultSimplify>,       // multiplies constant part and checks for 0
+      RewriteRule<MultPow2>            // replaces multiplication by a power of 2 by a shift
     >::apply(node);
 
   // only apply if every subterm was already rewritten 
@@ -317,7 +333,7 @@ RewriteResponse TheoryBVRewriter::RewriteMult(TNode node, bool preregister) {
   if(resultNode == node) {
     return RewriteResponse(REWRITE_DONE, resultNode); 
   }
-  return RewriteResponse(REWRITE_DONE, resultNode); 
+  return RewriteResponse(REWRITE_AGAIN_FULL, resultNode); 
 }
 
 RewriteResponse TheoryBVRewriter::RewritePlus(TNode node, bool preregister) {
index 174df03ab18a8ec5636d2a3d554983d659afcd2f..5847bac3eb066b9546dfb6082d251d00756652a4 100644 (file)
@@ -84,6 +84,7 @@ inline Node mkSortedNode(Kind kind, std::vector<Node>& children) {
 
 
 inline Node mkNode(Kind kind, std::vector<Node>& children) {
+  Assert (children.size() > 0); 
   if (children.size() == 1) {
     return children[0]; 
   }
@@ -133,6 +134,12 @@ inline Node mkXor(TNode node1, TNode node2) {
 }
 
 
+inline Node mkSignExtend(TNode node, unsigned ammount) {
+  NodeManager* nm = NodeManager::currentNM(); 
+  Node signExtendOp = nm->mkConst<BitVectorSignExtend>(BitVectorSignExtend(ammount));
+  return nm->mkNode(signExtendOp, node); 
+}
+
 inline Node mkExtract(TNode node, unsigned high, unsigned low) {
   Node extractOp = NodeManager::currentNM()->mkConst<BitVectorExtract>(BitVectorExtract(high, low));
   std::vector<Node> children;