Add rewriting rule for ranking benchmarks. (#1448)
authorMathias Preiner <mathias.preiner@gmail.com>
Thu, 21 Dec 2017 00:45:07 +0000 (16:45 -0800)
committerGitHub <noreply@github.com>
Thu, 21 Dec 2017 00:45:07 +0000 (16:45 -0800)
src/theory/bv/theory_bv_rewrite_rules.h
src/theory/bv/theory_bv_rewrite_rules_simplification.h
src/theory/bv/theory_bv_rewriter.cpp
src/theory/bv/theory_bv_utils.h

index 4d0f8033eea78b92fe2a8faf74e9ab21220c3a8e..1d2e7d9a399fa589dfe2bf9f4b81b36f8443665a 100644 (file)
@@ -179,7 +179,8 @@ enum RewriteRuleId {
   BBPlusNeg,
   UltPlusOne,
   ConcatToMult,
-  IsPowerOfTwo
+  IsPowerOfTwo,
+  MultSltMult,
 };
 
 inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
@@ -319,6 +320,7 @@ inline std::ostream& operator << (std::ostream& out, RewriteRuleId ruleId) {
   case UltPlusOne: out << "UltPlusOne"; return out;
   case ConcatToMult: out << "ConcatToMult"; return out;
   case IsPowerOfTwo: out << "IsPowerOfTwo"; return out;
+  case MultSltMult: out << "MultSltMult"; return out;
   default:
     Unreachable();
   }
@@ -545,6 +547,7 @@ struct AllRewriteRules {
   RewriteRule<ZeroExtendEqConst> rule125;
   RewriteRule<SignExtendUltConst> rule126;
   RewriteRule<ZeroExtendUltConst> rule127;
+  RewriteRule<MultSltMult> rule128;
 };
 
 template<> inline
index 24e5fb5e88e83a5be9b008e10f1c185759a98b1b..98a31189074705a9484827225bb5331f664eac58 100644 (file)
@@ -1431,9 +1431,153 @@ Node RewriteRule<IsPowerOfTwo>::apply(TNode node) {
   return x_eq_sh;
 }
 
+/**
+ * Rewrite
+ *   sign_extend(x+t,n) * sign_extend(a,m) < sign_extend(x,n) * sign_extend(a,m)
+ * to
+ *   (and
+ *    (not (= t zero))
+ *    (not (= a zero))
+ *    (= (bvslt (bvadd x t) x) (bvsgt a zero))
+ *   )
+ *
+ * Rewrite
+ *   zero_extend(x+t,n) * sign_extend(a,m) < zero_extend(x,n) * sign_extend(a,m)
+ * to
+ *   (and
+ *    (not (= t zero))
+ *    (not (= a zero))
+ *    (= (bvult (bvadd x t) x) (bvsgt a zero))
+ *   )
+ * where n and m are sufficiently big to not produce an overflow for
+ * the multiplications.
+ *
+ * These patterns occur in the quantified BV benchmark family 'ranking',
+ * where the BV engine struggles due to the high bit widths of the
+ * multiplication's operands.
+ */
+static std::tuple<Node, Node, bool>
+extract_ext_tuple(TNode node)
+{
+  TNode a = node[0];
+  TNode b = node[1];
+  for (unsigned i = 0; i < 2; ++i)
+  {
+    if (a.getKind() == kind::BITVECTOR_CONCAT
+        && b.getKind() == kind::BITVECTOR_SIGN_EXTEND
+        && a[0] == utils::mkZero(utils::getSize(a[0]))
+        && utils::getSize(a[1]) <= utils::getSize(a[0])
+        && utils::getSize(b[0]) <= utils::getSignExtendAmount(b))
+    {
+      return std::make_tuple(a[1], b[0], false);
+    }
+    else if (i == 0
+             && a.getKind() == kind::BITVECTOR_SIGN_EXTEND
+             && b.getKind() == kind::BITVECTOR_SIGN_EXTEND
+             && utils::getSize(a[0]) <= utils::getSignExtendAmount(a)
+             && utils::getSize(b[0]) <= utils::getSignExtendAmount(b))
+    {
+      return std::make_tuple(a[0], b[0], true);
+    }
+    std::swap(a, b);
+  }
+  return std::make_tuple(Node::null(), Node::null(), false);
+}
 
+template<> inline
+bool RewriteRule<MultSltMult>::applies(TNode node)
+{
+  if (node.getKind() != kind::BITVECTOR_SLT
+      || node[0].getKind() != kind::BITVECTOR_MULT
+      || node[1].getKind() != kind::BITVECTOR_MULT)
+    return false;
 
+  if (node[0].getNumChildren() > 2 || node[1].getNumChildren() > 2)
+    return false;
 
+  bool is_sext_l, is_sext_r;
+  TNode ml[2], mr[2];
+
+  std::tie(ml[0], ml[1], is_sext_l) = extract_ext_tuple(node[0]);
+  if (ml[0].isNull())
+    return false;
+
+  std::tie(mr[0], mr[1], is_sext_r) = extract_ext_tuple(node[1]);
+  if (mr[0].isNull())
+    return false;
+
+  if (is_sext_l != is_sext_r)
+    return false;
+
+  TNode addxt, x, a;
+  if (ml[0].getKind() == kind::BITVECTOR_PLUS)
+  {
+    addxt = ml[0];
+    a = ml[1];
+  }
+  else if (ml[1].getKind() == kind::BITVECTOR_PLUS)
+  {
+    addxt = ml[1];
+    a = ml[0];
+  }
+  else
+    return false;
+
+  if (addxt.getNumChildren() > 2)
+    return false;
+
+  if (mr[0] == a)
+  {
+    x = mr[1];
+  }
+  else if (mr[1] == a)
+  {
+    x = mr[0];
+  }
+  else
+    return false;
+
+  return (addxt[0] == x || addxt[1] == x);
 }
+
+template<> inline
+Node RewriteRule<MultSltMult>::apply(TNode node)
+{
+  bool is_sext;
+  TNode ml[2], mr[2];
+
+  std::tie(ml[0], ml[1], is_sext) = extract_ext_tuple(node[0]);
+  std::tie(mr[0], mr[1], std::ignore) = extract_ext_tuple(node[1]);
+
+  TNode addxt, x, t, a;
+  if (ml[0].getKind() == kind::BITVECTOR_PLUS)
+  {
+    addxt = ml[0];
+    a = ml[1];
+  }
+  else
+  {
+    Assert(ml[1].getKind() == kind::BITVECTOR_PLUS);
+    addxt = ml[1];
+    a = ml[0];
+  }
+
+  x = (mr[0] == a) ? mr[1] : mr[0];
+  t = (addxt[0] == x) ? addxt[1] : addxt[0];
+
+  NodeManager *nm = NodeManager::currentNM();
+  Node zero_t = utils::mkZero(utils::getSize(t));
+  Node zero_a = utils::mkZero(utils::getSize(a));
+
+  NodeBuilder<> nb(kind::AND);
+  Kind k = is_sext ? kind::BITVECTOR_SLT : kind::BITVECTOR_ULT;
+  nb << t.eqNode(zero_t).notNode();
+  nb << a.eqNode(zero_a).notNode();
+  nb << nm->mkNode(k, addxt, x)
+            .eqNode(nm->mkNode(kind::BITVECTOR_SGT, a, zero_a));
+  return nb.constructNode();
 }
-}
+
+}  // namespace bv
+}  // namespace theory
+}  // namespace CVC4
index bfaf517cc9733330c83ed5c6fae7a8b5f9f1a1fb..6bdd2ec28c161790aa1ed9c80a8f116998cc254a 100644 (file)
@@ -89,7 +89,8 @@ RewriteResponse TheoryBVRewriter::RewriteUltBv(TNode node, bool prerewrite) {
 
 RewriteResponse TheoryBVRewriter::RewriteSlt(TNode node, bool prerewrite){
   Node resultNode = LinearRewriteStrategy
-    < RewriteRule < EvalSlt >
+    < RewriteRule<EvalSlt>,
+      RewriteRule<MultSltMult>
        >::apply(node);
 
   return RewriteResponse(REWRITE_DONE, resultNode); 
index ea2dd4fc86c6b30dd2ae0646f44c08b3a7251c32..d9d1183af07792f63682e3bd07f8e5520abdf7df 100644 (file)
@@ -54,6 +54,11 @@ inline unsigned getSize(TNode node) {
   return node.getType().getBitVectorSize();
 }
 
+inline unsigned getSignExtendAmount(TNode node)
+{
+  return node.getOperator().getConst<BitVectorSignExtend>().signExtendAmount;
+}
+
 inline const bool getBit(TNode node, unsigned i) {
   Assert (i < utils::getSize(node) && 
           node.getKind() == kind::CONST_BITVECTOR);