Fix linearization for terms where the solve variable does not occur. (#1506)
authorMathias Preiner <mathias.preiner@gmail.com>
Wed, 10 Jan 2018 02:12:32 +0000 (18:12 -0800)
committerGitHub <noreply@github.com>
Wed, 10 Jan 2018 02:12:32 +0000 (18:12 -0800)
src/theory/quantifiers/ceg_t_instantiator.cpp
test/unit/theory/theory_quantifiers_bv_instantiator_white.h

index 2c71de666727c61bce5fa222abc0ce94ef09c3f9..275d7238d4bd279cd7578f88161ff5ada9c97df7 100644 (file)
@@ -1456,7 +1456,8 @@ static Node getPvCoeff(TNode pv, TNode n)
  *  pv * -(a * b * c)
  *
  * Returns the normalized node if the resulting term is linear w.r.t. pv and
- * a null node otherwise.
+ * a null node otherwise. If pv does not occur in children it returns a
+ * multiplication over children.
  */
 static Node normalizePvMult(
     TNode pv,
@@ -1518,13 +1519,21 @@ static Node normalizePvMult(
   {
     return zero;
   }
-  else if (coeff == bv::utils::mkOne(size_coeff))
+  Node result;
+  if (found_pv)
   {
-    return pv;
+    if (coeff == bv::utils::mkOne(size_coeff))
+    {
+      return pv;
+    }
+    result = nm->mkNode(BITVECTOR_MULT, pv, coeff);
+    contains_pv[result] = true;
+    result.setAttribute(is_linear, true);
+  }
+  else
+  {
+    result = coeff;
   }
-  Node result = nm->mkNode(BITVECTOR_MULT, pv, coeff);
-  contains_pv[result] = true;
-  result.setAttribute(is_linear, true);
   return result;
 }
 
@@ -1564,7 +1573,8 @@ static bool isLinearPlus(
  *  pv * (a - c) + b
  *
  * Returns the normalized node if the resulting term is linear w.r.t. pv and
- * a null node otherwise.
+ * a null node otherwise. If pv does not occur in children it returns an
+ * addition over children.
  */
 static Node normalizePvPlus(
     Node pv,
@@ -1609,6 +1619,7 @@ static Node normalizePvPlus(
     {
       Assert(isLinearPlus(nc, pv, contains_pv));
       Node coeff = getPvCoeff(pv, nc[0]);
+      Assert(!coeff.isNull());
       Node leaf = nc[1];
       if (neg)
       {
@@ -1622,22 +1633,23 @@ static Node normalizePvPlus(
     /* can't collect coefficients of 'pv' in 'cur' -> non-linear */
     return Node::null();
   }
-  Assert(nb_c.getNumChildren() > 0);
-
-  Node coeffs = (nb_c.getNumChildren() == 1) ? nb_c[0] : nb_c.constructNode();
-  coeffs = Rewriter::rewrite(coeffs);
+  Assert(nb_c.getNumChildren() > 0 || nb_l.getNumChildren() > 0);
 
-  std::vector<Node> mult_children = {pv, coeffs};
-  Node pv_mult_coeffs = normalizePvMult(pv, mult_children, contains_pv);
+  Node pv_mult_coeffs, result;
+  if (nb_c.getNumChildren() > 0)
+  {
+    Node coeffs = (nb_c.getNumChildren() == 1) ? nb_c[0] : nb_c.constructNode();
+    coeffs = Rewriter::rewrite(coeffs);
+    result = pv_mult_coeffs = normalizePvMult(pv, {pv, coeffs}, contains_pv);
+  }
 
   if (nb_l.getNumChildren() > 0)
   {
     Node leafs = (nb_l.getNumChildren() == 1) ? nb_l[0] : nb_l.constructNode();
     leafs = Rewriter::rewrite(leafs);
     Node zero = bv::utils::mkZero(bv::utils::getSize(pv));
-    Node result;
     /* pv * 0 + t --> t */
-    if (pv_mult_coeffs == zero)
+    if (pv_mult_coeffs.isNull() || pv_mult_coeffs == zero)
     {
       result = leafs;
     }
@@ -1647,9 +1659,9 @@ static Node normalizePvPlus(
       contains_pv[result] = true;
       result.setAttribute(is_linear, true);
     }
-    return result;
   }
-  return pv_mult_coeffs;
+  Assert(!result.isNull());
+  return result;
 }
 
 /**
index 1f3932be0fa415427bd3d990b4dbe7b2f4e9fefc..1e6578b275c89a44a4539cc3809c8e1294612aa4 100644 (file)
@@ -162,6 +162,10 @@ void BvInstantiatorWhite::testNormalizePvMult()
   Node norm_xx = normalizePvMult(x, {x, neg_x}, contains_x);
   TS_ASSERT(norm_xx.isNull());
 
+  /* nothing to normalize -> create a * a */
+  Node norm_aa = normalizePvMult(x, {a, a}, contains_x);
+  TS_ASSERT(norm_aa == Rewriter::rewrite(mkMult(a, a)));
+
   /* normalize x * a -> x * a */
   Node norm_xa = normalizePvMult(x, {x, a}, contains_x);
   TS_ASSERT(contains_x[norm_xa]);
@@ -258,6 +262,10 @@ void BvInstantiatorWhite::testNormalizePvPlus()
   Node norm_abx = normalizePvPlus(x, {a, mult_bx}, contains_x);
   TS_ASSERT(norm_abx.isNull());
 
+  /* nothing to normalize -> create a + a */
+  Node norm_aa = normalizePvPlus(x, {a, a}, contains_x);
+  TS_ASSERT(norm_aa == Rewriter::rewrite(mkPlus(a, a)));
+
   /* x + a -> x + a */
   Node norm_xa = normalizePvPlus(x, {x, a}, contains_x);
   TS_ASSERT(contains_x[norm_xa]);