Add learned rewrite preprocessing pass (#6842)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 6 Jul 2021 17:23:45 +0000 (12:23 -0500)
committerGitHub <noreply@github.com>
Tue, 6 Jul 2021 17:23:45 +0000 (17:23 +0000)
Adds the basic skeleton of the pass.

src/preprocessing/passes/learned_rewrite.cpp [new file with mode: 0644]
src/preprocessing/passes/learned_rewrite.h [new file with mode: 0644]

diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp
new file mode 100644 (file)
index 0000000..7858896
--- /dev/null
@@ -0,0 +1,181 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Rewriting based on learned literals
+ */
+
+#include "preprocessing/passes/learned_rewrite.h"
+
+#include "expr/skolem_manager.h"
+#include "expr/term_context_stack.h"
+#include "preprocessing/assertion_pipeline.h"
+#include "smt/smt_statistics_registry.h"
+#include "theory/arith/arith_msum.h"
+#include "theory/rewriter.h"
+#include "util/rational.h"
+
+using namespace cvc5::theory;
+using namespace cvc5::kind;
+
+namespace cvc5 {
+namespace preprocessing {
+namespace passes {
+
+const char* toString(LearnedRewriteId i)
+{
+  switch (i)
+  {
+    case LearnedRewriteId::NON_ZERO_DEN: return "NON_ZERO_DEN";
+    case LearnedRewriteId::INT_MOD_RANGE: return "INT_MOD_RANGE";
+    case LearnedRewriteId::PRED_POS_LB: return "PRED_POS_LB";
+    case LearnedRewriteId::PRED_ZERO_LB: return "PRED_ZERO_LB";
+    case LearnedRewriteId::PRED_NEG_UB: return "PRED_NEG_UB";
+    case LearnedRewriteId::NONE: return "NONE";
+    default: return "?LearnedRewriteId?";
+  }
+}
+
+std::ostream& operator<<(std::ostream& out, LearnedRewriteId i)
+{
+  out << toString(i);
+  return out;
+}
+
+LearnedRewrite::LearnedRewrite(PreprocessingPassContext* preprocContext)
+    : PreprocessingPass(preprocContext, "learned-rewrite"),
+      d_lrewCount(smtStatisticsRegistry().registerHistogram<LearnedRewriteId>(
+          "LearnedRewrite::lrewCount"))
+{
+}
+
+PreprocessingPassResult LearnedRewrite::applyInternal(
+    AssertionPipeline* assertionsToPreprocess)
+{
+  arith::BoundInference binfer;
+  std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
+  std::unordered_set<Node> llrw;
+  std::unordered_map<TNode, Node> visited;
+  if (learnedLits.empty())
+  {
+    Trace("learned-rewrite-ll") << "No learned literals" << std::endl;
+    return PreprocessingPassResult::NO_CONFLICT;
+  }
+  else
+  {
+    Trace("learned-rewrite-ll") << "Learned literals:" << std::endl;
+    for (const Node& l : learnedLits)
+    {
+      // maybe use the literal for bound inference?
+      Kind k = l.getKind();
+      Assert (k != LT && k != GT && k != LEQ);
+      if (k == EQUAL || k == GEQ)
+      {
+        binfer.add(l);
+      }
+      Trace("learned-rewrite-ll") << "- " << l << std::endl;
+    }
+    const std::map<Node, arith::Bounds>& bs = binfer.get();
+    // get the literals that were critical, i.e. used in the derivation of a
+    // bound
+    for (const std::pair<const Node, arith::Bounds>& b : bs)
+    {
+      for (size_t i = 0; i < 2; i++)
+      {
+        Node origin = i == 0 ? b.second.lower_origin : b.second.upper_origin;
+        if (!origin.isNull())
+        {
+          llrw.insert(origin);
+        }
+      }
+    }
+    // rewrite the non-critical learned literals, some may be redundant
+    for (const Node& l : learnedLits)
+    {
+      if (llrw.find(l) != llrw.end())
+      {
+        continue;
+      }
+      Node e = rewriteLearnedRec(l, binfer, llrw, visited);
+      if (e.isConst())
+      {
+        // ignore true
+        if (e.getConst<bool>())
+        {
+          continue;
+        }
+        // conflict, we are done
+        assertionsToPreprocess->push_back(e);
+        return PreprocessingPassResult::CONFLICT;
+      }
+      llrw.insert(e);
+    }
+    Trace("learned-rewrite-ll") << "end" << std::endl;
+  }
+  size_t size = assertionsToPreprocess->size();
+  for (size_t i = 0; i < size; ++i)
+  {
+    Node prev = (*assertionsToPreprocess)[i];
+    Trace("learned-rewrite-assert")
+        << "LearnedRewrite: assert: " << prev << std::endl;
+    Node e = rewriteLearnedRec(prev, binfer, llrw, visited);
+    if (e != prev)
+    {
+      Trace("learned-rewrite-assert")
+          << ".......................: " << e << std::endl;
+      assertionsToPreprocess->replace(i, e);
+    }
+  }
+  // Add the conjunction of learned literals back to assertions. Notice that
+  // in some cases we may add top-level assertions back to the assertion list
+  // unchanged.
+  if (!llrw.empty())
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    std::vector<Node> llrvec(llrw.begin(), llrw.end());
+    Node llc = nm->mkAnd(llrvec);
+    Trace("learned-rewrite-assert")
+        << "Re-add rewritten learned conjunction: " << llc << std::endl;
+    assertionsToPreprocess->push_back(llc);
+  }
+
+  return PreprocessingPassResult::NO_CONFLICT;
+}
+
+Node LearnedRewrite::rewriteLearnedRec(Node n,
+                                       arith::BoundInference& binfer,
+                                       std::unordered_set<Node>& lems,
+                                       std::unordered_map<TNode, Node>& visited)
+{
+  return n;
+}
+
+Node LearnedRewrite::rewriteLearned(Node n,
+                                    arith::BoundInference& binfer,
+                                    std::unordered_set<Node>& lems)
+{
+  return n;
+}
+
+Node LearnedRewrite::returnRewriteLearned(Node n, Node nr, LearnedRewriteId id)
+{
+  if (Trace.isOn("learned-rewrite"))
+  {
+    Trace("learned-rewrite") << "LearnedRewrite::Rewrite: (" << id << ") " << n
+                             << " == " << nr << std::endl;
+  }
+  d_lrewCount << id;
+  return nr;
+}
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace cvc5
diff --git a/src/preprocessing/passes/learned_rewrite.h b/src/preprocessing/passes/learned_rewrite.h
new file mode 100644 (file)
index 0000000..4f3a51d
--- /dev/null
@@ -0,0 +1,108 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Rewriting based on learned literals
+ */
+#include "cvc5_private.h"
+
+#ifndef CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H
+#define CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H
+
+#include "preprocessing/preprocessing_pass.h"
+#include "preprocessing/preprocessing_pass_context.h"
+#include "theory/arith/bound_inference.h"
+#include "util/statistics_stats.h"
+
+#include <iosfwd>
+
+namespace cvc5 {
+namespace preprocessing {
+namespace passes {
+
+/**
+ * Learned rewrites in the pass below.
+ */
+enum class LearnedRewriteId
+{
+  // Elimination of division, int division, int modulus due to non-zero
+  // denominator. e.g. (not (= y 0)) => (div x y) ---> (div_total x y)
+  NON_ZERO_DEN,
+  // Elimination of int modulus due to range.
+  // e.g. (and (<= 0 x) (< x n)) => (mod x n) ---> x
+  INT_MOD_RANGE,
+  // e.g. (>= c 0) => (>= p 0) ---> true where c is inferred const lower bound
+  PRED_POS_LB,
+  // e.g. (= c 0) => (>= p 0) ---> true where c is inferred const lower bound
+  PRED_ZERO_LB,
+  // e.g. (> c 0) => (>= p 0) ---> false where c is inferred const upper bound
+  PRED_NEG_UB,
+
+  //-------------------------------------- NONE
+  NONE
+};
+
+/**
+ * Converts an learned rewrite id to a string.
+ *
+ * @param i The learned rewrite identifier
+ * @return The name of the learned rewrite identifier
+ */
+const char* toString(LearnedRewriteId i);
+
+/**
+ * Writes an learned rewrite identifier to a stream.
+ *
+ * @param out The stream to write to
+ * @param i The learned rewrite identifier to write to the stream
+ * @return The stream
+ */
+std::ostream& operator<<(std::ostream& out, LearnedRewriteId i);
+
+/**
+ * Applies learned rewriting. This rewrites the input based on learned literals.
+ * This in particular does rewriting that goes beyond what is done in
+ * non-clausal simplification, where equality substitutions + constant
+ * propagations are performed. In particular, this pass applies rewriting
+ * based on e.g. bound inference for arithmetic.
+ */
+class LearnedRewrite : public PreprocessingPass
+{
+ public:
+  LearnedRewrite(PreprocessingPassContext* preprocContext);
+
+ protected:
+  PreprocessingPassResult applyInternal(
+      AssertionPipeline* assertionsToPreprocess) override;
+  /**
+   * Apply rewrite with learned literals, traverses n.
+   */
+  Node rewriteLearnedRec(Node n,
+                         theory::arith::BoundInference& binfer,
+                         std::unordered_set<Node>& lems,
+                         std::unordered_map<TNode, Node>& visited);
+  /**
+   * Learned rewrite to n, single step.
+   */
+  Node rewriteLearned(Node n,
+                      theory::arith::BoundInference& binfer,
+                      std::unordered_set<Node>& lems);
+  /** Return learned rewrite */
+  Node returnRewriteLearned(Node n, Node nr, LearnedRewriteId id);
+  /** Counts number of applications of learned rewrites */
+  HistogramStat<LearnedRewriteId> d_lrewCount;
+};
+
+}  // namespace passes
+}  // namespace preprocessing
+}  // namespace cvc5
+
+#endif /* CVC5__PREPROCESSING__PASSES__LEARNED_REWRITE_H */