Integrate learned rewrite preprocessing pass (#6840)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 6 Jul 2021 22:44:09 +0000 (17:44 -0500)
committerGitHub <noreply@github.com>
Tue, 6 Jul 2021 22:44:09 +0000 (17:44 -0500)
This adds the learned rewrite preprocessing pass, which rewrites the input formula based on (typically theory specific) reasoning about learned literals. The main motivation is for preprocessing ints division/modulus based on bounds.

src/CMakeLists.txt
src/options/smt_options.toml
src/preprocessing/passes/learned_rewrite.cpp
src/preprocessing/preprocessing_pass_registry.cpp
src/smt/process_assertions.cpp
src/smt/set_defaults.cpp

index 74db7c941d3fcf804019aefc8b08339ff71f767b..3246df654a7ed269686f27b6b9648e6bc68c5e7b 100644 (file)
@@ -88,6 +88,8 @@ libcvc5_add_sources(
   preprocessing/passes/ite_removal.h
   preprocessing/passes/ite_simp.cpp
   preprocessing/passes/ite_simp.h
+  preprocessing/passes/learned_rewrite.cpp
+  preprocessing/passes/learned_rewrite.h
   preprocessing/passes/miplib_trick.cpp
   preprocessing/passes/miplib_trick.h
   preprocessing/passes/nl_ext_purify.cpp
index 4d08aa67204822587ccfc08d81a7ec3855b469cf..9b5a93486341ccfbed63d99a9f6b99aecfd142fd 100644 (file)
@@ -47,6 +47,14 @@ name   = "SMT Layer"
   default    = "true"
   help       = "use static learning (on by default)"
 
+[[option]]
+  name       = "learnedRewrite"
+  category   = "regular"
+  long       = "learned-rewrite"
+  type       = "bool"
+  default    = "false"
+  help       = "rewrite the input based on learned literals"
+
 [[option]]
   name       = "expandDefinitions"
   long       = "expand-definitions"
index 4e9aa7bb2de8db2262fa3c9dc1c33fb6baad346b..fd3cf832b1d2362d79a535c7231dd8d718935d0b 100644 (file)
@@ -60,6 +60,7 @@ LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
 PreprocessingPassResult LearnedRewrite::applyInternal(
     AssertionPipeline* assertionsToPreprocess)
 {
+  NodeManager* nm = NodeManager::currentNM();
   arith::BoundInference binfer;
   std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
   std::unordered_set<Node> llrw;
@@ -72,14 +73,29 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
   else
   {
     Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
+    std::map<Node, Node> originLit;
     for (const Node& l : learnedLits)
     {
       // maybe use the literal for bound inference?
-      Kind k = l.getKind();
-      Assert(k != LT && k != GT && k != LEQ);
-      if (k == EQUAL || k == GEQ)
+      bool pol = l.getKind()!=NOT;
+      TNode atom = pol ? l : l[0];
+      Kind ak = atom.getKind();
+      Assert(ak != LT && ak != GT && ak != LEQ);
+      if ((ak == EQUAL && pol) || ak == GEQ)
       {
-        binfer.add(l);
+        // provide as < if negated >=
+        Node atomu;
+        if (!pol)
+        {
+          atomu = nm->mkNode(LT, atom[0], atom[1]);
+          originLit[atomu] = l;
+        }
+        else
+        {
+          atomu = l;
+          originLit[l] = l;
+        }
+        binfer.add(atomu);
       }
       Trace("learned-rewrite-ll") << "- " << l << std::endl;
     }
@@ -93,7 +109,8 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
         Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
         if (!origin.isNull())
         {
-          llrw.insert(origin);
+          Assert (originLit.find(origin)!=originLit.end());
+          llrw.insert(originLit[origin]);
         }
       }
     }
@@ -139,7 +156,6 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
   // unchanged.
   if (!llrw.empty())
   {
-    NodeManager* nm = NodeManager::currentNM();
     std::vector<Node> llrvec(llrw.begin(), llrw.end());
     Node llc = nm->mkAnd(llrvec);
     Trace("learned-rewrite-assert")
@@ -165,7 +181,13 @@ Node LearnedRewrite::rewriteLearnedRec(Node n,
     cur = visit.back();
     visit.pop_back();
     it = visited.find(cur);
-
+    if (lems.find(cur) != lems.end())
+    {
+      // n is a learned literal: replace by true, not considered a rewrite
+      // for statistics
+      visited[cur] = nm->mkConst(true);
+      continue;
+    }
     if (it == visited.end())
     {
       // mark pre-visited with null; will post-visit to construct final node
@@ -210,12 +232,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
                                     std::unordered_set<Node>& lems)
 {
   NodeManager* nm = NodeManager::currentNM();
-  if (lems.find(n) != lems.end())
-  {
-    // n is a learned literal: replace by true, not considered a rewrite
-    // for statistics
-    return nm->mkConst(true);
-  }
   Trace("learned-rewrite-rr-debug") << "Rewrite " << n << std::endl;
   Node nr = Rewriter::rewrite(n);
   Kind k = nr.getKind();
@@ -384,63 +400,6 @@ Node LearnedRewrite::rewriteLearned(Node n,
           return nr;
         }
       }
-      // inferences based on combining div terms
-      Node currDen;
-      Node currNum;
-      std::vector<Node> sum;
-      size_t divCount = 0;
-      bool divTotal = true;
-      for (const std::pair<const Node, Node>& m : msum)
-      {
-        if (m.first.isNull())
-        {
-          sum.push_back(m.second);
-          continue;
-        }
-        Kind mk = m.first.getKind();
-        if (mk == INTS_DIVISION || mk == INTS_DIVISION_TOTAL)
-        {
-          Node factor = ArithMSum::mkCoeffTerm(m.second, m.first[0]);
-          divTotal = divTotal && mk == INTS_DIVISION_TOTAL;
-          divCount++;
-          if (currDen.isNull())
-          {
-            currNum = factor;
-            currDen = m.first[1];
-          }
-          else
-          {
-            factor = nm->mkNode(MULT, factor, currDen);
-            currNum = nm->mkNode(MULT, currNum, m.first[1]);
-            currNum = nm->mkNode(PLUS, currNum, factor);
-            currDen = nm->mkNode(MULT, currDen, m.first[1]);
-          }
-        }
-        else
-        {
-          Node factor = ArithMSum::mkCoeffTerm(m.second, m.first);
-          sum.push_back(factor);
-        }
-      }
-      if (divCount >= 2)
-      {
-        SkolemManager* sm = nm->getSkolemManager();
-        Node r = sm->mkDummySkolem("r", nm->integerType());
-        Node d = nm->mkNode(
-            divTotal ? INTS_DIVISION_TOTAL : INTS_DIVISION, currNum, currDen);
-        sum.push_back(d);
-        sum.push_back(r);
-        Node bound =
-            nm->mkNode(AND,
-                       nm->mkNode(LEQ, nm->mkConst(-Rational(divCount - 1)), r),
-                       nm->mkNode(LEQ, r, nm->mkConst(Rational(divCount - 1))));
-        Node sumn = nm->mkNode(PLUS, sum);
-        Node lit = nm->mkNode(k, sumn, nm->mkConst(Rational(0)));
-        Node lemma = nm->mkNode(IMPLIES, nr, nm->mkNode(AND, lit, bound));
-        Trace("learned-rewrite-div")
-            << "Div collect lemma: " << lemma << std::endl;
-        lems.insert(lemma);
-      }
     }
   }
   return nr;
index 6f846dc7488b86f1252f213a4911a875b546b1da..f0bd5af86445b85e4001e9a98cf703046fbaa30c 100644 (file)
@@ -41,6 +41,7 @@
 #include "preprocessing/passes/int_to_bv.h"
 #include "preprocessing/passes/ite_removal.h"
 #include "preprocessing/passes/ite_simp.h"
+#include "preprocessing/passes/learned_rewrite.h"
 #include "preprocessing/passes/miplib_trick.h"
 #include "preprocessing/passes/nl_ext_purify.h"
 #include "preprocessing/passes/non_clausal_simp.h"
@@ -126,6 +127,7 @@ PreprocessingPassRegistry::PreprocessingPassRegistry()
   registerPassInfo("global-negate", callCtor<GlobalNegate>);
   registerPassInfo("int-to-bv", callCtor<IntToBV>);
   registerPassInfo("bv-to-int", callCtor<BVToInt>);
+  registerPassInfo("learned-rewrite", callCtor<LearnedRewrite>);
   registerPassInfo("foreign-theory-rewrite", callCtor<ForeignTheoryRewrite>);
   registerPassInfo("synth-rr", callCtor<SynthRewRulesPass>);
   registerPassInfo("real-to-int", callCtor<RealToInt>);
index cf747c36000da109104e4f3ab3b764c1e1c0f04a..a9426d5bd38d04bd3fc8936e1cdea34bb6cfdf43 100644 (file)
@@ -288,6 +288,11 @@ bool ProcessAssertions::apply(Assertions& as)
   }
   Debug("smt") << " assertions     : " << assertions.size() << endl;
 
+  if (options::learnedRewrite())
+  {
+    d_passes["learned-rewrite"]->apply(&assertions);
+  }
+
   if (options::earlyIteRemoval())
   {
     d_smtStats.d_numAssertionsPre += assertions.size();
index ee3701d513a490c3d2d1d153f9c8a081e6aed7f1..229fdeec55b2c3b99017fd2e657613611657c395 100644 (file)
@@ -489,15 +489,27 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
       opts.smt.simplificationMode = options::SimplificationMode::NONE;
     }
 
+    if (options::learnedRewrite())
+    {
+      if (opts.smt.learnedRewriteWasSetByUser)
+      {
+        throw OptionException(
+            "learned rewrites not supported with unsat cores");
+      }
+      Notice() << "SmtEngine: turning off learned rewrites to support "
+                  "unsat cores\n";
+      opts.smt.learnedRewrite = false;
+    }
+
     if (options::pbRewrites())
     {
       if (opts.arith.pbRewritesWasSetByUser)
       {
         throw OptionException(
-            "pseudoboolean rewrites not supported with old unsat cores");
+            "pseudoboolean rewrites not supported with unsat cores");
       }
       Notice() << "SmtEngine: turning off pseudoboolean rewrites to support "
-                  "old unsat cores\n";
+                  "unsat cores\n";
       opts.arith.pbRewrites = false;
     }
 
@@ -505,10 +517,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     {
       if (opts.smt.sortInferenceWasSetByUser)
       {
-        throw OptionException(
-            "sort inference not supported with old unsat cores");
+        throw OptionException("sort inference not supported with unsat cores");
       }
-      Notice() << "SmtEngine: turning off sort inference to support old unsat "
+      Notice() << "SmtEngine: turning off sort inference to support unsat "
                   "cores\n";
       opts.smt.sortInference = false;
     }
@@ -518,9 +529,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
       if (opts.quantifiers.preSkolemQuantWasSetByUser)
       {
         throw OptionException(
-            "pre-skolemization not supported with old unsat cores");
+            "pre-skolemization not supported with unsat cores");
       }
-      Notice() << "SmtEngine: turning off pre-skolemization to support old "
+      Notice() << "SmtEngine: turning off pre-skolemization to support "
                   "unsat cores\n";
       opts.quantifiers.preSkolemQuant = false;
     }
@@ -529,9 +540,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     {
       if (opts.bv.bitvectorToBoolWasSetByUser)
       {
-        throw OptionException("bv-to-bool not supported with old unsat cores");
+        throw OptionException("bv-to-bool not supported with unsat cores");
       }
-      Notice() << "SmtEngine: turning off bitvector-to-bool to support old "
+      Notice() << "SmtEngine: turning off bitvector-to-bool to support "
                   "unsat cores\n";
       opts.bv.bitvectorToBool = false;
     }
@@ -541,10 +552,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
       if (opts.bv.boolToBitvectorWasSetByUser)
       {
         throw OptionException(
-            "bool-to-bv != off not supported with old unsat cores");
+            "bool-to-bv != off not supported with unsat cores");
       }
-      Notice()
-          << "SmtEngine: turning off bool-to-bv to support old unsat cores\n";
+      Notice() << "SmtEngine: turning off bool-to-bv to support unsat cores\n";
       opts.bv.boolToBitvector = options::BoolToBVMode::OFF;
     }
 
@@ -552,11 +562,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     {
       if (opts.bv.bvIntroducePow2WasSetByUser)
       {
-        throw OptionException(
-            "bv-intro-pow2 not supported with old unsat cores");
+        throw OptionException("bv-intro-pow2 not supported with unsat cores");
       }
-      Notice()
-          << "SmtEngine: turning off bv-intro-pow2 to support old unsat cores";
+      Notice() << "SmtEngine: turning off bv-intro-pow2 to support unsat cores";
       opts.bv.bvIntroducePow2 = false;
     }
 
@@ -564,10 +572,9 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     {
       if (opts.smt.repeatSimpWasSetByUser)
       {
-        throw OptionException("repeat-simp not supported with old unsat cores");
+        throw OptionException("repeat-simp not supported with unsat cores");
       }
-      Notice()
-          << "SmtEngine: turning off repeat-simp to support old unsat cores\n";
+      Notice() << "SmtEngine: turning off repeat-simp to support unsat cores\n";
       opts.smt.repeatSimp = false;
     }
 
@@ -575,22 +582,21 @@ void setDefaults(LogicInfo& logic, bool isInternalSubsolver)
     {
       if (opts.quantifiers.globalNegateWasSetByUser)
       {
-        throw OptionException(
-            "global-negate not supported with old unsat cores");
+        throw OptionException("global-negate not supported with unsat cores");
       }
-      Notice() << "SmtEngine: turning off global-negate to support old unsat "
+      Notice() << "SmtEngine: turning off global-negate to support unsat "
                   "cores\n";
       opts.quantifiers.globalNegate = false;
     }
 
     if (options::bitvectorAig())
     {
-      throw OptionException("bitblast-aig not supported with old unsat cores");
+      throw OptionException("bitblast-aig not supported with unsat cores");
     }
 
     if (options::doITESimp())
     {
-      throw OptionException("ITE simp not supported with old unsat cores");
+      throw OptionException("ITE simp not supported with unsat cores");
     }
   }
   else