Improvements to extended rewriter for Booleans and ITE (#1705)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 3 Apr 2018 01:03:16 +0000 (20:03 -0500)
committerGitHub <noreply@github.com>
Tue, 3 Apr 2018 01:03:16 +0000 (20:03 -0500)
src/options/quantifiers_options.toml
src/theory/quantifiers/extended_rewrite.cpp
src/theory/quantifiers/extended_rewrite.h
src/theory/quantifiers/term_util.cpp
src/theory/quantifiers/term_util.h

index 28a9e58a7a8ae621569e6723118e081103f7613e..f877143a233640b75fb719e85906e56bfdb8d8dc 100644 (file)
@@ -1085,6 +1085,14 @@ header = "options/quantifiers_options.h"
   default    = "false"
   help       = "enumerate a stream of solutions instead of terminating after the first one"
 
+[[option]]
+  name       = "sygusExtRew"
+  category   = "regular"
+  long       = "sygus-ext-rew"
+  type       = "bool"
+  default    = "true"
+  help       = "use extended rewriter for sygus"
+
 [[option]]
   name       = "cegisSample"
   category   = "regular"
index dd4fc86ba908ff49838337481d521062a7f0dbd7..756413b540a112b86bfff9df89a0f758d5fcb92e 100644 (file)
@@ -14,7 +14,9 @@
 
 #include "theory/quantifiers/extended_rewrite.h"
 
+#include "options/quantifiers_options.h"
 #include "theory/arith/arith_msum.h"
+#include "theory/bv/theory_bv_utils.h"
 #include "theory/datatypes/datatypes_rewriter.h"
 #include "theory/quantifiers/term_util.h"
 #include "theory/rewriter.h"
@@ -26,201 +28,176 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
+struct ExtRewriteAttributeId
+{
+};
+typedef expr::Attribute<ExtRewriteAttributeId, Node> ExtRewriteAttribute;
+
 ExtendedRewriter::ExtendedRewriter(bool aggr) : d_aggr(aggr)
 {
-  d_true = NodeManager::currentNM()->mkConst(true);
-  d_false = NodeManager::currentNM()->mkConst(false);
 }
-
-Node ExtendedRewriter::extendedRewritePullIte(Node n)
+void ExtendedRewriter::setCache(Node n, Node ret)
 {
-  // generalize this?
-  Assert(n.getNumChildren() == 2);
-  Assert(n.getType().isBoolean());
-  Assert(n.getMetaKind() != kind::metakind::PARAMETERIZED);
-  std::vector<Node> children;
-  for (unsigned i = 0; i < n.getNumChildren(); i++)
-  {
-    children.push_back(n[i]);
-  }
-  for (unsigned i = 0; i < 2; i++)
-  {
-    if (n[i].getKind() == kind::ITE)
-    {
-      for (unsigned j = 0; j < 2; j++)
-      {
-        children[i] = n[i][j + 1];
-        Node eqr = extendedRewrite(
-            NodeManager::currentNM()->mkNode(n.getKind(), children));
-        children[i] = n[i];
-        if (eqr.isConst())
-        {
-          std::vector<Node> new_children;
-          Kind new_k;
-          if (eqr == d_true)
-          {
-            new_k = kind::OR;
-            new_children.push_back(j == 0 ? n[i][0] : n[i][0].negate());
-          }
-          else
-          {
-            Assert(eqr == d_false);
-            new_k = kind::AND;
-            new_children.push_back(j == 0 ? n[i][0].negate() : n[i][0]);
-          }
-          children[i] = n[i][2 - j];
-          Node rem_eq = NodeManager::currentNM()->mkNode(n.getKind(), children);
-          children[i] = n[i];
-          new_children.push_back(rem_eq);
-          Node nc = NodeManager::currentNM()->mkNode(new_k, new_children);
-          Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to "
-                                 << nc << " by simple ITE pulling."
-                                 << std::endl;
-          return nc;
-        }
-      }
-    }
-  }
-  return Node::null();
+  ExtRewriteAttribute era;
+  n.setAttribute(era, ret);
 }
 
 Node ExtendedRewriter::extendedRewrite(Node n)
 {
   n = Rewriter::rewrite(n);
-  std::unordered_map<Node, Node, NodeHashFunction>::iterator it =
-      d_ext_rewrite_cache.find(n);
-  if (it != d_ext_rewrite_cache.end())
+  if (!options::sygusExtRew())
+  {
+    return n;
+  }
+
+  // has it already been computed?
+  if (n.hasAttribute(ExtRewriteAttribute()))
   {
-    return it->second;
+    return n.getAttribute(ExtRewriteAttribute());
   }
+
   Node ret = n;
+  NodeManager* nm = NodeManager::currentNM();
+
+  //--------------------pre-rewrite
+  Node pre_new_ret;
+  if (ret.getKind() == IMPLIES)
+  {
+    pre_new_ret = nm->mkNode(OR, ret[0].negate(), ret[1]);
+    debugExtendedRewrite(ret, pre_new_ret, "IMPLIES elim");
+  }
+  else if (ret.getKind() == XOR)
+  {
+    pre_new_ret = nm->mkNode(EQUAL, ret[0].negate(), ret[1]);
+    debugExtendedRewrite(ret, pre_new_ret, "XOR elim");
+  }
+  else if (ret.getKind() == NOT)
+  {
+    pre_new_ret = extendedRewriteNnf(ret);
+    debugExtendedRewrite(ret, pre_new_ret, "NNF");
+  }
+  if (!pre_new_ret.isNull())
+  {
+    ret = extendedRewrite(pre_new_ret);
+    Trace("q-ext-rewrite-debug") << "...ext-pre-rewrite : " << n << " -> "
+                                 << pre_new_ret << std::endl;
+    setCache(n, ret);
+    return ret;
+  }
+  //--------------------end pre-rewrite
+
+  //--------------------rewrite children
   if (n.getNumChildren() > 0)
   {
     std::vector<Node> children;
-    if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
+    if (n.getMetaKind() == metakind::PARAMETERIZED)
     {
       children.push_back(n.getOperator());
     }
+    Kind k = n.getKind();
     bool childChanged = false;
+    bool isNonAdditive = TermUtil::isNonAdditive(k);
     for (unsigned i = 0; i < n.getNumChildren(); i++)
     {
       Node nc = extendedRewrite(n[i]);
       childChanged = nc != n[i] || childChanged;
-      children.push_back(nc);
+      // If the operator is non-additive, do not consider duplicates
+      if (isNonAdditive
+          && std::find(children.begin(), children.end(), nc) != children.end())
+      {
+        childChanged = true;
+      }
+      else
+      {
+        children.push_back(nc);
+      }
     }
+    Assert(!children.empty());
     // Some commutative operators have rewriters that are agnostic to order,
     // thus, we sort here.
-    if (TermUtil::isComm(n.getKind()) && (d_aggr || children.size() <= 5))
+    if (TermUtil::isComm(k) && (d_aggr || children.size() <= 5))
     {
       childChanged = true;
       std::sort(children.begin(), children.end());
     }
     if (childChanged)
     {
-      ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
+      if (isNonAdditive && children.size() == 1)
+      {
+        // we may have subsumed children down to one
+        ret = children[0];
+      }
+      else
+      {
+        ret = nm->mkNode(k, children);
+      }
     }
   }
   ret = Rewriter::rewrite(ret);
+  //--------------------end rewrite children
+
+  // now, do extended rewrite
   Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret
                                << " (from " << n << ")" << std::endl;
-
   Node new_ret;
-  if (ret.getKind() == kind::EQUAL)
+
+  //---------------------- theory-independent post-rewriting
+  if (ret.getKind() == ITE)
   {
-    if (new_ret.isNull())
-    {
-      // simple ITE pulling
-      new_ret = extendedRewritePullIte(ret);
-    }
+    new_ret = extendedRewriteIte(ITE, ret);
   }
-  else if (ret.getKind() == kind::ITE)
+  else if (ret.getKind() == AND || ret.getKind() == OR)
   {
-    Assert(ret[1] != ret[2]);
-    if (ret[0].getKind() == NOT)
-    {
-      ret = NodeManager::currentNM()->mkNode(
-          kind::ITE, ret[0][0], ret[2], ret[1]);
-    }
-    if (ret[0].getKind() == kind::EQUAL)
-    {
-      // simple invariant ITE
-      for (unsigned i = 0; i < 2; i++)
-      {
-        if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i])
-        {
-          Trace("q-ext-rewrite")
-              << "sygus-extr : " << ret << " rewrites to " << ret[2]
-              << " due to simple invariant ITE." << std::endl;
-          new_ret = ret[2];
-          break;
-        }
-      }
-      // notice this is strictly more general than the above
-      if (new_ret.isNull())
-      {
-        // simple substitution
-        for (unsigned i = 0; i < 2; i++)
-        {
-          TNode r1 = ret[0][i];
-          TNode r2 = ret[0][1 - i];
-          if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
-          {
-            Node retn = ret[1].substitute(r1, r2);
-            if (retn != ret[1])
-            {
-              new_ret = NodeManager::currentNM()->mkNode(
-                  kind::ITE, ret[0], retn, ret[2]);
-              Trace("q-ext-rewrite")
-                  << "sygus-extr : " << ret << " rewrites to " << new_ret
-                  << " due to simple ITE substitution." << std::endl;
-            }
-          }
-        }
-      }
-    }
+    // all kinds are legal to substitute over : hence we give the empty map
+    std::map<Kind, bool> bcp_kinds;
+    new_ret = extendedRewriteBcp(AND, OR, NOT, bcp_kinds, ret);
+    debugExtendedRewrite(ret, new_ret, "Bool bcp");
   }
-  else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION
-           || ret.getKind() == INTS_MODULUS)
+  else if (ret.getKind() == EQUAL)
   {
-    // rewrite as though total
-    std::vector<Node> children;
-    bool all_const = true;
-    for (unsigned i = 0; i < ret.getNumChildren(); i++)
+    new_ret = extendedRewriteEqChain(EQUAL, AND, OR, NOT, ret);
+    debugExtendedRewrite(ret, new_ret, "Bool eq-chain simplify");
+  }
+  if (new_ret.isNull() && ret.getKind() != ITE)
+  {
+    // simple ITE pulling
+    new_ret = extendedRewritePullIte(ITE, ret);
+  }
+  //----------------------end theory-independent post-rewriting
+
+  //----------------------theory-specific post-rewriting
+  if (new_ret.isNull())
+  {
+    Node atom = ret.getKind() == NOT ? ret[0] : ret;
+    bool pol = ret.getKind() != NOT;
+    TheoryId tid = Theory::theoryOf(atom);
+    if (tid == THEORY_ARITH)
     {
-      if (ret[i].isConst())
-      {
-        children.push_back(ret[i]);
-      }
-      else
-      {
-        all_const = false;
-        break;
-      }
+      new_ret = extendedRewriteArith(atom, pol);
     }
-    if (all_const)
+    // add back negation if not processed
+    if (!pol && !new_ret.isNull())
     {
-      Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
-                                              : (ret.getKind() == INTS_DIVISION
-                                                     ? INTS_DIVISION_TOTAL
-                                                     : INTS_MODULUS_TOTAL));
-      new_ret = NodeManager::currentNM()->mkNode(new_k, children);
-      Trace("q-ext-rewrite")
-          << "sygus-extr : " << ret << " rewrites to " << new_ret
-          << " due to total interpretation." << std::endl;
+      new_ret = new_ret.negate();
     }
   }
-  // more expensive rewrites
+  //----------------------end theory-specific post-rewriting
+
+  //----------------------aggressive rewrites
   if (new_ret.isNull() && d_aggr)
   {
     new_ret = extendedRewriteAggr(ret);
   }
+  //----------------------end aggressive rewrites
 
-  d_ext_rewrite_cache[n] = ret;
+  setCache(n, ret);
   if (!new_ret.isNull())
   {
     ret = extendedRewrite(new_ret);
   }
-  d_ext_rewrite_cache[n] = ret;
+  Trace("q-ext-rewrite-debug") << "...ext-rewrite : " << n << " -> " << ret
+                               << std::endl;
+  setCache(n, ret);
   return ret;
 }
 
@@ -234,6 +211,8 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
   if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal())
       || ret_atom.getKind() == GEQ)
   {
+    // ITE term removal in polynomials
+    // e.g. ite( x=0, x, y ) = x+1 ---> ( x=0 ^ y = x+1 )
     Trace("q-ext-rewrite-debug2")
         << "Compute monomial sum " << ret_atom << std::endl;
     // compute monomial sum
@@ -255,7 +234,7 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
             Trace("q-ext-rewrite-debug")
                 << "  have ITE relation, solved form : " << veq << std::endl;
             // try pulling ITE
-            new_ret = extendedRewritePullIte(veq);
+            new_ret = extendedRewritePullIte(ITE, veq);
             if (!new_ret.isNull())
             {
               if (!polarity)
@@ -279,10 +258,781 @@ Node ExtendedRewriter::extendedRewriteAggr(Node n)
           << "  failed to get monomial sum of " << n << std::endl;
     }
   }
-  // TODO (#1599) : conditional rewriting, condition merging
+  // TODO (#1706) : conditional rewriting, condition merging
+  return new_ret;
+}
+
+Node ExtendedRewriter::extendedRewriteIte(Kind itek, Node n, bool full)
+{
+  Assert(n.getKind() == itek);
+  Assert(n[1] != n[2]);
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  Trace("ext-rew-ite") << "Rewrite ITE : " << n << std::endl;
+
+  Node flip_cond;
+  if (n[0].getKind() == NOT)
+  {
+    flip_cond = n[0][0];
+  }
+  else if (n[0].getKind() == OR)
+  {
+    // a | b ---> ~( ~a & ~b )
+    flip_cond = TermUtil::simpleNegate(n[0]);
+  }
+  if (!flip_cond.isNull())
+  {
+    Node new_ret = nm->mkNode(ITE, flip_cond, n[2], n[1]);
+    // only print debug trace if full=true
+    if (full)
+    {
+      debugExtendedRewrite(n, new_ret, "ITE flip");
+    }
+    return new_ret;
+  }
+
+  // get entailed equalities in the condition
+  std::vector<Node> eq_conds;
+  Kind ck = n[0].getKind();
+  if (ck == EQUAL)
+  {
+    eq_conds.push_back(n[0]);
+  }
+  else if (ck == AND)
+  {
+    for (const Node& cn : n[0])
+    {
+      if (cn.getKind() == EQUAL)
+      {
+        eq_conds.push_back(cn);
+      }
+    }
+  }
+
+  Node new_ret;
+  Node b;
+  Node e;
+  Node t1 = n[1];
+  Node t2 = n[2];
+  std::stringstream ss_reason;
+
+  for (const Node& eq : eq_conds)
+  {
+    // simple invariant ITE
+    for (unsigned i = 0; i <= 1; i++)
+    {
+      // ite( x = y ^ C, y, x ) ---> x
+      // this is subsumed by the rewrites below
+      if (t2 == eq[i] && t1 == eq[1 - i])
+      {
+        new_ret = t2;
+        ss_reason << "ITE simple rev subs";
+        break;
+      }
+    }
+    if (!new_ret.isNull())
+    {
+      break;
+    }
+  }
+
+  if (new_ret.isNull() && d_aggr)
+  {
+    // If x is less than t based on an ordering, then we use { x -> t } as a
+    // substitution to the children of ite( x = t ^ C, s, t ) below.
+    std::vector<Node> vars;
+    std::vector<Node> subs;
+    for (const Node& eq : eq_conds)
+    {
+      inferSubstitution(eq, vars, subs);
+    }
+
+    if (!vars.empty())
+    {
+      // reverse substitution to opposite child
+      // r{ x -> t } = s  implies  ite( x=t ^ C, s, r ) ---> r
+      Node nn =
+          t2.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+      if (nn != t2)
+      {
+        nn = Rewriter::rewrite(nn);
+        if (nn == t1)
+        {
+          new_ret = t2;
+          ss_reason << "ITE rev subs";
+        }
+      }
+
+      // ite( x=t ^ C, s, r ) ---> ite( x=t ^ C, s{ x -> t }, r )
+      nn = t1.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+      if (nn != t1)
+      {
+        // If full=false, then we've duplicated a term u in the children of n.
+        // For example, when ITE pulling, we have n is of the form:
+        //   ite( C, f( u, t1 ), f( u, t2 ) )
+        // We must show that at least one copy of u dissappears in this case.
+        nn = Rewriter::rewrite(nn);
+        if (nn == t2)
+        {
+          new_ret = nn;
+          ss_reason << "ITE subs invariant";
+        }
+        else if (full || nn.isConst())
+        {
+          new_ret = nm->mkNode(itek, n[0], nn, t2);
+          ss_reason << "ITE subs";
+        }
+      }
+    }
+  }
+
+  // only print debug trace if full=true
+  if (!new_ret.isNull() && full)
+  {
+    debugExtendedRewrite(n, new_ret, ss_reason.str().c_str());
+  }
+
+  return new_ret;
+}
+
+Node ExtendedRewriter::extendedRewritePullIte(Kind itek, Node n)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  TypeNode tn = n.getType();
+  std::vector<Node> children;
+  bool hasOp = (n.getMetaKind() == metakind::PARAMETERIZED);
+  if (hasOp)
+  {
+    children.push_back(n.getOperator());
+  }
+  unsigned nchildren = n.getNumChildren();
+  for (unsigned i = 0; i < nchildren; i++)
+  {
+    children.push_back(n[i]);
+  }
+  std::map<unsigned, std::map<unsigned, Node> > ite_c;
+  for (unsigned i = 0; i < nchildren; i++)
+  {
+    if (n[i].getKind() == itek)
+    {
+      unsigned ii = hasOp ? i + 1 : i;
+      for (unsigned j = 0; j < 2; j++)
+      {
+        children[ii] = n[i][j + 1];
+        Node pull = nm->mkNode(n.getKind(), children);
+        Node pullr = Rewriter::rewrite(pull);
+        children[ii] = n[i];
+        ite_c[i][j] = pullr;
+      }
+      if (ite_c[i][0] == ite_c[i][1])
+      {
+        // ITE dual invariance
+        // f( t1..s1..tn ) ---> t  and  f( t1..s2..tn ) ---> t implies
+        // f( t1..ite( A, s1, s2 )..tn ) ---> t
+        debugExtendedRewrite(n, ite_c[i][0], "ITE dual invariant");
+        return ite_c[i][0];
+      }
+      else if (d_aggr)
+      {
+        for (unsigned j = 0; j < 2; j++)
+        {
+          Node pullr = ite_c[i][j];
+          if (pullr.isConst() || pullr == n[i][j + 1])
+          {
+            // ITE single child elimination
+            // f( t1..s1..tn ) ---> t  where t is a constant or s1 itself
+            // implies
+            // f( t1..ite( A, s1, s2 )..tn ) ---> ite( A, t, f( t1..s2..tn ) )
+            Node new_ret;
+            if (tn.isBoolean())
+            {
+              // remove false/true child immediately
+              bool pol = pullr.getConst<bool>();
+              std::vector<Node> new_children;
+              new_children.push_back((j == 0) == pol ? n[i][0]
+                                                     : n[i][0].negate());
+              new_children.push_back(ite_c[i][1 - j]);
+              new_ret = nm->mkNode(pol ? OR : AND, new_children);
+              debugExtendedRewrite(n, new_ret, "ITE Bool single elim");
+            }
+            else
+            {
+              new_ret = nm->mkNode(itek, n[i][0], ite_c[i][0], ite_c[i][1]);
+              debugExtendedRewrite(n, new_ret, "ITE single elim");
+            }
+            return new_ret;
+          }
+        }
+      }
+    }
+  }
+
+  for (std::pair<const unsigned, std::map<unsigned, Node> >& ip : ite_c)
+  {
+    Node nite = n[ip.first];
+    Assert(nite.getKind() == itek);
+    // now, simply pull the ITE and try ITE rewrites
+    Node pull_ite = nm->mkNode(itek, nite[0], ip.second[0], ip.second[1]);
+    pull_ite = Rewriter::rewrite(pull_ite);
+    if (pull_ite.getKind() == ITE)
+    {
+      Node new_pull_ite = extendedRewriteIte(itek, pull_ite, false);
+      if (!new_pull_ite.isNull())
+      {
+        debugExtendedRewrite(n, new_pull_ite, "ITE pull rewrite");
+        return new_pull_ite;
+      }
+    }
+    else
+    {
+      // A general rewrite could eliminate the ITE by pulling.
+      // An example is:
+      //   ~( ite( C, ~x, ~ite( C, y, x ) ) ) --->
+      //   ite( C, ~~x, ite( C, y, x ) ) --->
+      //   x
+      // where ~ is bitvector negation.
+      debugExtendedRewrite(n, pull_ite, "ITE pull basic elim");
+      return pull_ite;
+    }
+  }
+
+  return Node::null();
+}
+
+Node ExtendedRewriter::extendedRewriteNnf(Node ret)
+{
+  Assert(ret.getKind() == NOT);
+
+  Kind nk = ret[0].getKind();
+  bool neg_ch = false;
+  bool neg_ch_1 = false;
+  if (nk == AND || nk == OR)
+  {
+    neg_ch = true;
+    nk = nk == AND ? OR : AND;
+  }
+  else if (nk == IMPLIES)
+  {
+    neg_ch = true;
+    neg_ch_1 = true;
+    nk = AND;
+  }
+  else if (nk == ITE)
+  {
+    neg_ch = true;
+    neg_ch_1 = true;
+  }
+  else if (nk == XOR)
+  {
+    nk = EQUAL;
+  }
+  else if (nk == EQUAL && ret[0][0].getType().isBoolean())
+  {
+    neg_ch_1 = true;
+  }
+  else
+  {
+    return Node::null();
+  }
+
+  std::vector<Node> new_children;
+  for (unsigned i = 0, nchild = ret[0].getNumChildren(); i < nchild; i++)
+  {
+    Node c = ret[0][i];
+    c = (i == 0 ? neg_ch_1 : false) != neg_ch ? c.negate() : c;
+    new_children.push_back(c);
+  }
+  return NodeManager::currentNM()->mkNode(nk, new_children);
+}
+
+Node ExtendedRewriter::extendedRewriteBcp(
+    Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node ret)
+{
+  Kind k = ret.getKind();
+  Assert(k == andk || k == ork);
+  Trace("ext-rew-bcp") << "BCP: **** INPUT: " << ret << std::endl;
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  TypeNode tn = ret.getType();
+  Node truen = TermUtil::mkTypeMaxValue(tn);
+  Node falsen = TermUtil::mkTypeValue(tn, 0);
+
+  // terms to process
+  std::vector<Node> to_process;
+  for (const Node& cn : ret)
+  {
+    to_process.push_back(cn);
+  }
+  // the processing terms
+  std::vector<Node> clauses;
+  // the terms we have propagated information to
+  std::unordered_set<Node, NodeHashFunction> prop_clauses;
+  // the assignment
+  std::map<Node, Node> assign;
+  std::vector<Node> avars;
+  std::vector<Node> asubs;
+
+  Kind ok = k == andk ? ork : andk;
+  // global polarity : when k=ork, everything is negated
+  bool gpol = k == andk;
+
+  do
+  {
+    // process the current nodes
+    while (!to_process.empty())
+    {
+      std::vector<Node> new_to_process;
+      for (const Node& cn : to_process)
+      {
+        Trace("ext-rew-bcp-debug") << "process " << cn << std::endl;
+        Kind cnk = cn.getKind();
+        bool pol = cnk != notk;
+        Node cln = cnk == notk ? cn[0] : cn;
+        Assert(cln.getKind() != notk);
+        if ((pol && cln.getKind() == k) || (!pol && cln.getKind() == ok))
+        {
+          // flatten
+          for (const Node& ccln : cln)
+          {
+            Node lccln = pol ? ccln : TermUtil::mkNegate(notk, ccln);
+            new_to_process.push_back(lccln);
+          }
+        }
+        else
+        {
+          // add it to the assignment
+          Node val = gpol == pol ? truen : falsen;
+          std::map<Node, Node>::iterator it = assign.find(cln);
+          Trace("ext-rew-bcp") << "BCP: assign " << cln << " -> " << val
+                               << std::endl;
+          if (it != assign.end())
+          {
+            if (val != it->second)
+            {
+              Trace("ext-rew-bcp") << "BCP: conflict!" << std::endl;
+              // a conflicting assignment: we are done
+              return gpol ? falsen : truen;
+            }
+          }
+          else
+          {
+            assign[cln] = val;
+            avars.push_back(cln);
+            asubs.push_back(val);
+          }
+
+          // also, treat it as clause if possible
+          if (cln.getNumChildren() > 0
+              & (bcp_kinds.empty()
+                 || bcp_kinds.find(cln.getKind()) != bcp_kinds.end()))
+          {
+            if (std::find(clauses.begin(), clauses.end(), cn) == clauses.end()
+                && prop_clauses.find(cn) == prop_clauses.end())
+            {
+              Trace("ext-rew-bcp") << "BCP: new clause: " << cn << std::endl;
+              clauses.push_back(cn);
+            }
+          }
+        }
+      }
+      to_process.clear();
+      to_process.insert(
+          to_process.end(), new_to_process.begin(), new_to_process.end());
+    }
+
+    // apply substitution to all subterms of clauses
+    std::vector<Node> new_clauses;
+    for (const Node& c : clauses)
+    {
+      bool cpol = c.getKind() != notk;
+      Node ca = c.getKind() == notk ? c[0] : c;
+      bool childChanged = false;
+      std::vector<Node> ccs_children;
+      for (const Node& cc : ca)
+      {
+        Node ccs = cc;
+        if (bcp_kinds.empty())
+        {
+          Trace("ext-rew-bcp-debug") << "...do ordinary substitute"
+                                     << std::endl;
+          ccs = cc.substitute(
+              avars.begin(), avars.end(), asubs.begin(), asubs.end());
+        }
+        else
+        {
+          Trace("ext-rew-bcp-debug") << "...do partial substitute" << std::endl;
+          // substitution is only applicable to compatible kinds
+          ccs = partialSubstitute(ccs, assign, bcp_kinds);
+        }
+        childChanged = childChanged || ccs != cc;
+        ccs_children.push_back(ccs);
+      }
+      if (childChanged)
+      {
+        if (ca.getMetaKind() == metakind::PARAMETERIZED)
+        {
+          ccs_children.insert(ccs_children.begin(), ca.getOperator());
+        }
+        Node ccs = nm->mkNode(ca.getKind(), ccs_children);
+        ccs = cpol ? ccs : TermUtil::mkNegate(notk, ccs);
+        Trace("ext-rew-bcp") << "BCP: propagated " << c << " -> " << ccs
+                             << std::endl;
+        ccs = Rewriter::rewrite(ccs);
+        Trace("ext-rew-bcp") << "BCP: rewritten to " << ccs << std::endl;
+        to_process.push_back(ccs);
+        // store this as a node that propagation touched. This marks c so that
+        // it will not be included in the final construction.
+        prop_clauses.insert(ca);
+      }
+      else
+      {
+        new_clauses.push_back(c);
+      }
+    }
+    clauses.clear();
+    clauses.insert(clauses.end(), new_clauses.begin(), new_clauses.end());
+  } while (!to_process.empty());
+
+  // remake the node
+  if (!prop_clauses.empty())
+  {
+    std::vector<Node> children;
+    for (std::pair<const Node, Node>& l : assign)
+    {
+      Node a = l.first;
+      // if propagation did not touch a
+      if (prop_clauses.find(a) == prop_clauses.end())
+      {
+        Assert(l.second == truen || l.second == falsen);
+        Node ln = (l.second == truen) == gpol ? a : TermUtil::mkNegate(notk, a);
+        children.push_back(ln);
+      }
+    }
+    Node new_ret = children.size() == 1 ? children[0] : nm->mkNode(k, children);
+    Trace("ext-rew-bcp") << "BCP: **** OUTPUT: " << new_ret << std::endl;
+    return new_ret;
+  }
+
+  return Node::null();
+}
+
+Node ExtendedRewriter::extendedRewriteEqChain(
+    Kind eqk, Kind andk, Kind ork, Kind notk, Node ret, bool isXor)
+{
+  Assert(ret.getKind() == eqk);
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  TypeNode tn = ret[0].getType();
+
+  // sort/cancelling for Boolean EQUAL/XOR-chains
+  Trace("ext-rew-eqchain") << "Eq-Chain : " << ret << std::endl;
+
+  // get the children on either side
+  bool gpol = true;
+  std::vector<Node> children;
+  for (unsigned r = 0, size = ret.getNumChildren(); r < size; r++)
+  {
+    Node curr = ret[r];
+    // assume, if necessary, right associative
+    while (curr.getKind() == eqk && curr[0].getType() == tn)
+    {
+      children.push_back(curr[0]);
+      curr = curr[1];
+    }
+    children.push_back(curr);
+  }
+
+  std::map<Node, bool> cstatus;
+  // add children to status
+  for (const Node& c : children)
+  {
+    Node a = c;
+    if (a.getKind() == notk)
+    {
+      gpol = !gpol;
+      a = a[0];
+    }
+    Trace("ext-rew-eqchain") << "...child : " << a << std::endl;
+    std::map<Node, bool>::iterator itc = cstatus.find(a);
+    if (itc == cstatus.end())
+    {
+      cstatus[a] = true;
+    }
+    else
+    {
+      // cancels
+      cstatus.erase(a);
+      if (isXor)
+      {
+        gpol = !gpol;
+      }
+    }
+  }
+  Trace("ext-rew-eqchain") << "Global polarity : " << gpol << std::endl;
+
+  if (cstatus.empty())
+  {
+    return TermUtil::mkTypeConst(tn, gpol);
+  }
+
+  children.clear();
+
+  // cancel AND/OR children if possible
+  for (std::pair<const Node, bool>& cp : cstatus)
+  {
+    if (cp.second)
+    {
+      Node c = cp.first;
+      Kind ck = c.getKind();
+      if (ck == andk || ck == ork)
+      {
+        for (unsigned j = 0, nchild = c.getNumChildren(); j < nchild; j++)
+        {
+          Node cl = c[j];
+          Node ca = cl.getKind() == notk ? cl[0] : cl;
+          bool capol = cl.getKind() != notk;
+          // if this already exists as a child of the equality chain
+          std::map<Node, bool>::iterator itc = cstatus.find(ca);
+          if (itc != cstatus.end() && itc->second)
+          {
+            // cancel it
+            cstatus[ca] = false;
+            cstatus[c] = false;
+            // make new child
+            // x = ( y | ~x ) ---> y & x
+            // x = ( y | x ) ---> ~y | x
+            // x = ( y & x ) ---> y | ~x
+            // x = ( y & ~x ) ---> ~y & ~x
+            std::vector<Node> new_children;
+            for (unsigned k = 0, nchild = c.getNumChildren(); k < nchild; k++)
+            {
+              if (j != k)
+              {
+                new_children.push_back(c[k]);
+              }
+            }
+            Node nc[2];
+            nc[0] = c[j];
+            nc[1] = new_children.size() == 1 ? new_children[0]
+                                             : nm->mkNode(ck, new_children);
+            // negate the proper child
+            unsigned nindex = (ck == andk) == capol ? 0 : 1;
+            nc[nindex] = TermUtil::mkNegate(notk, nc[nindex]);
+            Kind nk = capol ? ork : andk;
+            // store as new child
+            children.push_back(nm->mkNode(nk, nc[0], nc[1]));
+            if (isXor)
+            {
+              gpol = !gpol;
+            }
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // sorted right associative chain
+  bool has_const = false;
+  unsigned const_index = 0;
+  for (std::pair<const Node, bool>& cp : cstatus)
+  {
+    if (cp.second)
+    {
+      if (cp.first.isConst())
+      {
+        has_const = true;
+        const_index = children.size();
+      }
+      children.push_back(cp.first);
+    }
+  }
+  std::sort(children.begin(), children.end());
+
+  Node new_ret;
+  if (!gpol)
+  {
+    // negate the constant child if it exists
+    unsigned nindex = has_const ? const_index : 0;
+    children[nindex] = TermUtil::mkNegate(notk, children[nindex]);
+  }
+  new_ret = children.back();
+  unsigned index = children.size() - 1;
+  while (index > 0)
+  {
+    index--;
+    new_ret = nm->mkNode(eqk, children[index], new_ret);
+  }
+  new_ret = Rewriter::rewrite(new_ret);
+  if (new_ret != ret)
+  {
+    return new_ret;
+  }
+  return Node::null();
+}
+
+Node ExtendedRewriter::partialSubstitute(Node n,
+                                         std::map<Node, Node>& assign,
+                                         std::map<Kind, bool>& rkinds)
+{
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      std::map<Node, Node>::iterator it = assign.find(cur);
+      if (it != assign.end())
+      {
+        visited[cur] = it->second;
+      }
+      else
+      {
+        // can only recurse on these kinds
+        Kind k = cur.getKind();
+        if (rkinds.find(k) != rkinds.end())
+        {
+          visited[cur] = Node::null();
+          visit.push_back(cur);
+          for (const Node& cn : cur)
+          {
+            visit.push_back(cn);
+          }
+        }
+        else
+        {
+          visited[cur] = cur;
+        }
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      bool childChanged = false;
+      std::vector<Node> children;
+      if (cur.getMetaKind() == metakind::PARAMETERIZED)
+      {
+        children.push_back(cur.getOperator());
+      }
+      for (const Node& cn : cur)
+      {
+        it = visited.find(cn);
+        Assert(it != visited.end());
+        Assert(!it->second.isNull());
+        childChanged = childChanged || cn != it->second;
+        children.push_back(it->second);
+      }
+      if (childChanged)
+      {
+        ret = NodeManager::currentNM()->mkNode(cur.getKind(), children);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
+Node ExtendedRewriter::solveEquality(Node n)
+{
+  // TODO (#1706) : implement
+  Assert(n.getKind() == EQUAL);
+
+  return Node::null();
+}
+
+bool ExtendedRewriter::inferSubstitution(Node n,
+                                         std::vector<Node>& vars,
+                                         std::vector<Node>& subs)
+{
+  if (n.getKind() == EQUAL)
+  {
+    // see if it can be put into form x = y
+    Node slv_eq = solveEquality(n);
+    if (!slv_eq.isNull())
+    {
+      n = slv_eq;
+    }
+    for (unsigned i = 0; i < 2; i++)
+    {
+      TNode r1 = n[i];
+      TNode r2 = n[1 - i];
+      if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst()))
+      {
+        // TODO (#1706) : union find
+        if (std::find(vars.begin(), vars.end(), r1) == vars.end())
+        {
+          vars.push_back(r1);
+          subs.push_back(r2);
+          return true;
+        }
+      }
+    }
+  }
+  return false;
+}
+
+Node ExtendedRewriter::extendedRewriteArith(Node ret, bool& pol)
+{
+  Kind k = ret.getKind();
+  NodeManager* nm = NodeManager::currentNM();
+  Node new_ret;
+  if (k == DIVISION || k == INTS_DIVISION || k == INTS_MODULUS)
+  {
+    // rewrite as though total
+    std::vector<Node> children;
+    bool all_const = true;
+    for (unsigned i = 0, size = ret.getNumChildren(); i < size; i++)
+    {
+      if (ret[i].isConst())
+      {
+        children.push_back(ret[i]);
+      }
+      else
+      {
+        all_const = false;
+        break;
+      }
+    }
+    if (all_const)
+    {
+      Kind new_k = (ret.getKind() == DIVISION ? DIVISION_TOTAL
+                                              : (ret.getKind() == INTS_DIVISION
+                                                     ? INTS_DIVISION_TOTAL
+                                                     : INTS_MODULUS_TOTAL));
+      new_ret = nm->mkNode(new_k, children);
+      debugExtendedRewrite(ret, new_ret, "total-interpretation");
+    }
+  }
   return new_ret;
 }
 
+void ExtendedRewriter::debugExtendedRewrite(Node n,
+                                            Node ret,
+                                            const char* c) const
+{
+  if (Trace.isOn("q-ext-rewrite"))
+  {
+    if (!ret.isNull())
+    {
+      Trace("q-ext-rewrite-apply") << "sygus-extr : apply " << c << std::endl;
+      Trace("q-ext-rewrite") << "sygus-extr : " << c << " : " << n
+                             << " rewrites to " << ret << std::endl;
+    }
+  }
+}
+
 } /* CVC4::theory::quantifiers namespace */
 } /* CVC4::theory namespace */
 } /* CVC4 namespace */
index 25d710a6b05c2e661cef51d55ceb7e0f66cc7962..2daa42b18725162cada91be979a43b96f88f6edd 100644 (file)
@@ -35,10 +35,14 @@ namespace quantifiers {
  *
  * This class extended the standard techniques for rewriting
  * with techniques, including but not limited to:
- * - ITE branch merging,
+ * - Redundant child elimination,
+ * - Sorting children of commutative operators,
+ * - Boolean constraint propagation,
+ * - Equality chain normalization,
+ * - Negation normal form,
+ * - Simple ITE pulling,
  * - ITE conditional variable elimination,
- * - ITE condition subsumption, and
- * - Aggressive rewriting for string equalities.
+ * - ITE condition subsumption.
  */
 class ExtendedRewriter
 {
@@ -60,21 +64,128 @@ class ExtendedRewriter
    * may be applied as a preprocessing step.
    */
   bool d_aggr;
-  /** true and false nodes */
-  Node d_true;
-  Node d_false;
-  /** cache for extendedRewrite */
-  std::unordered_map<Node, Node, NodeHashFunction> d_ext_rewrite_cache;
-  /** pull ITE
-   * Do simple ITE pulling, e.g.:
-   *   C2 --->^E false
-   * implies:
-   *  ite( C, C1, C2 ) --->^E  C ^ C1
-   * where ---->^E denotes extended rewriting.
+  /** cache that the extended rewritten form of n is ret */
+  void setCache(Node n, Node ret);
+
+  //--------------------------------------generic utilities
+  /** Rewrite ITE, for example:
+   *
+   * ite( ~C, s, t ) ---> ite( C, t, s )
+   * ite( A or B, s, t ) ---> ite( ~A and ~B, t, s )
+   * ite( x = c, x, t ) --> ite( x = c, c, t )
+   * t * { x -> c } = s => ite( x = c, s, t ) ---> t
+   *
+   * The parameter "full" indicates an effort level that this rewrite will
+   * take. If full is false, then we do only perform rewrites that
+   * strictly decrease the term size of n.
+   */
+  Node extendedRewriteIte(Kind itek, Node n, bool full = true);
+  /** Pull ITE, for example:
+   *
+   *   D=C2 ---> false
+   *     implies
+   *   D=ite( C, C1, C2 ) --->  C ^ D=C1
+   *
+   *   f(t,t1) --> s  and  f(t,t2)---> s
+   *     implies
+   *   f(t,ite(C,t1,t2)) ---> s
+   *
+   * If this function returns a non-null node ret, then n ---> ret.
+   */
+  Node extendedRewritePullIte(Kind itek, Node n);
+  /** Negation Normal Form (NNF), for example:
+   *
+   *   ~( A & B ) ---> ( ~ A | ~B )
+   *   ~( ite( A, B, C ) ---> ite( A, ~B, ~C )
+   *
+   * If this function returns a non-null node ret, then n ---> ret.
+   */
+  Node extendedRewriteNnf(Node n);
+  /** (type-independent) Boolean constraint propagation, for example:
+   *
+   *   ~A & ( B V A ) ---> ~A & B
+   *   A & ( B = ( A V C ) ) ---> A & B
+   *
+   * This function takes as arguments the kinds that specify AND, OR, and NOT.
+   * It additionally takes as argument a map bcp_kinds. If this map is
+   * non-empty, then all terms that have a Kind that is *not* in this map should
+   * be treated as immutable. This is for instance to prevent propagation
+   * beneath illegal terms. As an example:
+   *   (bvand A (bvor A B)) is equivalent to (bvand A (bvor 1...1 B)), but
+   *   (bvand A (bvplus A B)) is not equivalent to (bvand A (bvplus 1..1 B)),
+   * hence, when using this function to do BCP for bit-vectors, we have that
+   * BITVECTOR_AND is a bcp_kind, but BITVECTOR_PLUS is not.
+   *
+   * If this function returns a non-null node ret, then n ---> ret.
+   */
+  Node extendedRewriteBcp(
+      Kind andk, Kind ork, Kind notk, std::map<Kind, bool>& bcp_kinds, Node n);
+  /** (type-independent) Equality chain rewriting, for example:
+   *
+   *   A = ( A = B ) ---> B
+   *   ( A = D ) = ( C = B ) ---> A = ( B = ( C = D ) )
+   *   A = ( A & B ) ---> ~A | B
+   *
+   * If this function returns a non-null node ret, then n ---> ret.
+   * This function takes as arguments the kinds that specify EQUAL, AND, OR,
+   * and NOT. If the flag isXor is true, the eqk is treated as XOR.
+   */
+  Node extendedRewriteEqChain(
+      Kind eqk, Kind andk, Kind ork, Kind notk, Node n, bool isXor = false);
+  /** extended rewrite aggressive
+   *
+   * All aggressive rewriting techniques (those that should be prioritized
+   * at a lower level) go in this function.
    */
-  Node extendedRewritePullIte(Node n);
-  /** extended rewrite aggressive */
   Node extendedRewriteAggr(Node n);
+  /** Decompose right associative chain
+   *
+   * For term f( ... f( f( base, tn ), t{n-1} ) ... t1 ), returns term base, and
+   * appends t1...tn to children.
+   */
+  Node decomposeRightAssocChain(Kind k, Node n, std::vector<Node>& children);
+  /** Make right associative chain
+   *
+   * Sorts children to obtain list { tn...t1 }, and returns the term
+   * f( ... f( f( base, tn ), t{n-1} ) ... t1 ).
+   */
+  Node mkRightAssocChain(Kind k, Node base, std::vector<Node>& children);
+  /** Partial substitute
+   *
+   * Applies the substitution specified by assign to n, recursing only beneath
+   * terms whose Kind appears in rec_kinds.
+   */
+  Node partialSubstitute(Node n,
+                         std::map<Node, Node>& assign,
+                         std::map<Kind, bool>& rkinds);
+  /** solve equality
+   *
+   * If this function returns a non-null node n', then n' is equivalent to n
+   * and is of the form that can be used by inferSubstitution below.
+   */
+  Node solveEquality(Node n);
+  /** infer substitution
+   *
+   * If n is an equality of the form x = t, where t is either:
+   * (1) a constant, or
+   * (2) a variable y such that x < y based on an ordering,
+   * then this method adds x to vars and y to subs and return true, otherwise
+   * it returns false.
+   */
+  bool inferSubstitution(Node n,
+                         std::vector<Node>& vars,
+                         std::vector<Node>& subs);
+  /** extended rewrite
+   *
+   * Prints debug information, indicating the rewrite n ---> ret was found.
+   */
+  inline void debugExtendedRewrite(Node n, Node ret, const char* c) const;
+  //--------------------------------------end generic utilities
+
+  //--------------------------------------theory-specific top-level calls
+  /** extended rewrite arith */
+  Node extendedRewriteArith(Node ret, bool& pol);
+  //--------------------------------------end theory-specific top-level calls
 };
 
 } /* CVC4::theory::quantifiers namespace */
index 3b8d033990ea5c7df4bae9f47bb275738e667bda..5965906cbd04fef281031dc50ab589b48e0d76e1 100644 (file)
@@ -773,13 +773,22 @@ bool TermUtil::containsUninterpretedConstant( Node n ) {
 Node TermUtil::simpleNegate( Node n ){
   if( n.getKind()==OR || n.getKind()==AND ){
     std::vector< Node > children;
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      children.push_back( simpleNegate( n[i] ) );
+    for (const Node& cn : n)
+    {
+      children.push_back(simpleNegate(cn));
     }
     return NodeManager::currentNM()->mkNode( n.getKind()==OR ? AND : OR, children );
-  }else{
-    return n.negate();
   }
+  return n.negate();
+}
+
+Node TermUtil::mkNegate(Kind notk, Node n)
+{
+  if (n.getKind() == notk)
+  {
+    return n[0];
+  }
+  return NodeManager::currentNM()->mkNode(notk, n);
 }
 
 bool TermUtil::isAssoc( Kind k ) {
@@ -912,6 +921,11 @@ Node TermUtil::getTypeValueOffset(TypeNode tn,
   return it->second;
 }
 
+Node TermUtil::mkTypeConst(TypeNode tn, bool pol)
+{
+  return pol ? mkTypeValue(tn, 0) : mkTypeMaxValue(tn);
+}
+
 bool TermUtil::isAntisymmetric(Kind k, Kind& dk)
 {
   if (k == GT)
index 8ec2fc8e2acecd7277811c31bf82c867c31f5575..97f4edcd5cfa735cff0692a7b4e9c6d0c80c3c0f 100644 (file)
@@ -289,6 +289,11 @@ public:
   static int getTermDepth( Node n );
   /** simple negate */
   static Node simpleNegate( Node n );
+  /**
+   * Make negated term, returns the negation of n wrt Kind notk, eliminating
+   * double negation if applicable, e.g. mkNegate( ~, ~x ) ---> x.
+   */
+  static Node mkNegate(Kind notk, Node n);
   /** is assoc */
   static bool isAssoc( Kind k );
   /** is k commutative? */
@@ -364,6 +369,13 @@ public:
   static Node mkTypeValueOffset(TypeNode tn, Node val, int offset, int& status);
   /** make max value, static version of get max value */
   static Node mkTypeMaxValue(TypeNode tn);
+  /**
+   * Make const, returns pol ? mkTypeValue(tn,0) : mkTypeMaxValue(tn).
+   * In other words, this returns either the minimum element of tn if pol is
+   * true, and the maximum element in pol is false. The type tn should have
+   * minimum and maximum elements, for example tn is Bool or BitVector.
+   */
+  static Node mkTypeConst(TypeNode tn, bool pol);
 
   // for higher-order
  private: