Add implementation of learned rewrite pass (#6843)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 6 Jul 2021 20:00:38 +0000 (15:00 -0500)
committerGitHub <noreply@github.com>
Tue, 6 Jul 2021 20:00:38 +0000 (20:00 +0000)
src/preprocessing/passes/learned_rewrite.cpp

index 78588966673704e79ddad991846ca2337691f6d9..4e9aa7bb2de8db2262fa3c9dc1c33fb6baad346b 100644 (file)
@@ -76,7 +76,7 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
     {
       // maybe use the literal for bound inference?
       Kind k = l.getKind();
-      Assert (k != LT && k != GT && k != LEQ);
+      Assert(k != LT && k != GT && k != LEQ);
       if (k == EQUAL || k == GEQ)
       {
         binfer.add(l);
@@ -155,14 +155,295 @@ Node LearnedRewrite::rewriteLearnedRec(Node n,
                                        std::unordered_set<Node>& lems,
                                        std::unordered_map<TNode, Node>& visited)
 {
-  return n;
+  NodeManager* nm = NodeManager::currentNM();
+  std::unordered_map<TNode, Node>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      // mark pre-visited with null; will post-visit to construct final node
+      // in the block below.
+      visited[cur] = Node::null();
+      visit.push_back(cur);
+      visit.insert(visit.end(), cur.begin(), cur.end());
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      bool needsRcons = false;
+      std::vector<Node> children;
+      if (cur.getMetaKind() == kind::metakind::PARAMETERIZED)
+      {
+        children.push_back(cur.getOperator());
+      }
+      for (const Node& cn : cur)
+      {
+        it = visited.find(cn);
+        Assert(it != visited.end());
+        Assert(!it->second.isNull());
+        needsRcons = needsRcons || cn != it->second;
+        children.push_back(it->second);
+      }
+      if (needsRcons)
+      {
+        ret = nm->mkNode(cur.getKind(), children);
+      }
+      // rewrite here
+      ret = rewriteLearned(ret, binfer, lems);
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
 }
 
 Node LearnedRewrite::rewriteLearned(Node n,
                                     arith::BoundInference& binfer,
                                     std::unordered_set<Node>& lems)
 {
-  return n;
+  NodeManager* nm = NodeManager::currentNM();
+  if (lems.find(n) != lems.end())
+  {
+    // n is a learned literal: replace by true, not considered a rewrite
+    // for statistics
+    return nm->mkConst(true);
+  }
+  Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
+  Node nr = Rewriter::rewrite(n);
+  Kind k = nr.getKind();
+  if (k == INTS_DIVISION || k == INTS_MODULUS || k == DIVISION)
+  {
+    // simpler if we know the divisor is non-zero
+    Node num = n[0];
+    Node den = n[1];
+    bool isNonZeroDen = false;
+    if (den.isConst())
+    {
+      isNonZeroDen = (den.getConst<Rational>().sgn() != 0);
+    }
+    else
+    {
+      arith::Bounds db = binfer.get(den);
+      Trace("learned-rewrite-rr-debug")
+          << "Bounds for " << den << " : " << db.lower_value << " "
+          << db.upper_value << std::endl;
+      if (!db.lower_value.isNull()
+          && db.lower_value.getConst<Rational>().sgn() == 1)
+      {
+        isNonZeroDen = true;
+      }
+      else if (!db.upper_value.isNull()
+               && db.upper_value.getConst<Rational>().sgn() == -1)
+      {
+        isNonZeroDen = true;
+      }
+    }
+    if (isNonZeroDen)
+    {
+      Trace("learned-rewrite-rr-debug")
+          << "...non-zero denominator" << std::endl;
+      Kind nk = k;
+      switch (k)
+      {
+        case INTS_DIVISION: nk = INTS_DIVISION_TOTAL; break;
+        case INTS_MODULUS: nk = INTS_MODULUS_TOTAL; break;
+        case DIVISION: nk = DIVISION_TOTAL; break;
+        default: Assert(false); break;
+      }
+      std::vector<Node> children;
+      children.insert(children.end(), n.begin(), n.end());
+      Node ret = nm->mkNode(nk, children);
+      nr = returnRewriteLearned(nr, ret, LearnedRewriteId::NON_ZERO_DEN);
+      nr = Rewriter::rewrite(nr);
+      k = nr.getKind();
+    }
+  }
+  // constant int mod elimination by bound inference
+  if (k == INTS_MODULUS_TOTAL)
+  {
+    Node num = n[0];
+    Node den = n[1];
+    arith::Bounds db = binfer.get(den);
+    if ((!db.lower_value.isNull()
+         && db.lower_value.getConst<Rational>().sgn() == 1)
+        || (!db.upper_value.isNull()
+            && db.upper_value.getConst<Rational>().sgn() == -1))
+    {
+      Rational bden = db.lower_value.isNull()
+                          ? db.lower_value.getConst<Rational>()
+                          : db.upper_value.getConst<Rational>().abs();
+      // if 0 <= UB(num) < LB(den) or 0 <= UB(num) < -UB(den)
+      arith::Bounds nb = binfer.get(num);
+      if (!nb.upper_value.isNull())
+      {
+        Rational bnum = nb.upper_value.getConst<Rational>();
+        if (bnum.sgn() != -1 && bnum < bden)
+        {
+          nr = returnRewriteLearned(nr, nr[0], LearnedRewriteId::INT_MOD_RANGE);
+        }
+      }
+      // could also do num + k*den checks
+    }
+  }
+  else if (k == GEQ || (k == EQUAL && nr[0].getType().isReal()))
+  {
+    std::map<Node, Node> msum;
+    if (ArithMSum::getMonomialSumLit(nr, msum))
+    {
+      Rational lb(0);
+      Rational ub(0);
+      bool lbSuccess = true;
+      bool ubSuccess = true;
+      Rational one(1);
+      if (Trace.isOn("learned-rewrite-arith-lit"))
+      {
+        Trace("learned-rewrite-arith-lit")
+            << "Arithmetic lit: " << nr << std::endl;
+        for (const std::pair<const Node, Node>& m : msum)
+        {
+          Trace("learned-rewrite-arith-lit")
+              << "  " << m.first << ", " << m.second << std::endl;
+        }
+      }
+      for (const std::pair<const Node, Node>& m : msum)
+      {
+        bool isOneCoeff = m.second.isNull();
+        Assert(isOneCoeff || m.second.isConst());
+        if (m.first.isNull())
+        {
+          lb = lb + (isOneCoeff ? one : m.second.getConst<Rational>());
+          ub = ub + (isOneCoeff ? one : m.second.getConst<Rational>());
+        }
+        else
+        {
+          arith::Bounds b = binfer.get(m.first);
+          bool isNeg = !isOneCoeff && m.second.getConst<Rational>().sgn() == -1;
+          // flip lower/upper if negative coefficient
+          TNode l = isNeg ? b.upper_value : b.lower_value;
+          TNode u = isNeg ? b.lower_value : b.upper_value;
+          if (lbSuccess && !l.isNull())
+          {
+            Rational lc = l.getConst<Rational>();
+            lb = lb
+                 + (isOneCoeff ? lc
+                               : Rational(lc * m.second.getConst<Rational>()));
+          }
+          else
+          {
+            lbSuccess = false;
+          }
+          if (ubSuccess && !u.isNull())
+          {
+            Rational uc = u.getConst<Rational>();
+            ub = ub
+                 + (isOneCoeff ? uc
+                               : Rational(uc * m.second.getConst<Rational>()));
+          }
+          else
+          {
+            ubSuccess = false;
+          }
+          if (!lbSuccess && !ubSuccess)
+          {
+            break;
+          }
+        }
+      }
+      if (lbSuccess)
+      {
+        if (lb.sgn() == 1)
+        {
+          // if positive lower bound, then GEQ is true, EQUAL is false
+          Node ret = nm->mkConst(k == GEQ);
+          nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_POS_LB);
+          return nr;
+        }
+        else if (lb.sgn() == 0 && k == GEQ)
+        {
+          // zero lower bound, GEQ is true
+          Node ret = nm->mkConst(true);
+          nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_ZERO_LB);
+          return nr;
+        }
+      }
+      else if (ubSuccess)
+      {
+        if (ub.sgn() == -1)
+        {
+          // if negative upper bound, then GEQ and EQUAL are false
+          Node ret = nm->mkConst(false);
+          nr = returnRewriteLearned(nr, ret, LearnedRewriteId::PRED_NEG_UB);
+          return nr;
+        }
+      }
+      // inferences based on combining div terms
+      Node currDen;
+      Node currNum;
+      std::vector<Node> sum;
+      size_t divCount = 0;
+      bool divTotal = true;
+      for (const std::pair<const Node, Node>& m : msum)
+      {
+        if (m.first.isNull())
+        {
+          sum.push_back(m.second);
+          continue;
+        }
+        Kind mk = m.first.getKind();
+        if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL)
+        {
+          Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]);
+          divTotal = divTotal && mk == INTS_DIVISION_TOTAL;
+          divCount++;
+          if (currDen.isNull())
+          {
+            currNum = factor;
+            currDen = m.first[1];
+          }
+          else
+          {
+            factor = nm->mkNode(MULT, factor, currDen);
+            currNum = nm->mkNode(MULT, currNum, m.first[1]);
+            currNum = nm->mkNode(PLUS, currNum, factor);
+            currDen = nm->mkNode(MULT, currDen, m.first[1]);
+          }
+        }
+        else
+        {
+          Node factor = ArithMSum::mkCoeffTerm(m.second, m.first);
+          sum.push_back(factor);
+        }
+      }
+      if (divCount >= 2)
+      {
+        SkolemManager* sm = nm->getSkolemManager();
+        Node r = sm->mkDummySkolem("r", nm->integerType());
+        Node d = nm->mkNode(
+            divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen);
+        sum.push_back(d);
+        sum.push_back(r);
+        Node bound =
+            nm->mkNode(AND,
+                       nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r),
+                       nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1))));
+        Node sumn = nm->mkNode(PLUS, sum);
+        Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0)));
+        Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound));
+        Trace("learned-rewrite-div")
+            << "Div collect lemma: " << lemma << std::endl;
+        lems.insert(lemma);
+      }
+    }
+  }
+  return nr;
 }
 
 Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)