Implement aggressive pruning in CAD solver (#7650)
authorGereon Kremer <gkremer@stanford.edu>
Wed, 17 Nov 2021 06:41:51 +0000 (22:41 -0800)
committerGitHub <noreply@github.com>
Wed, 17 Nov 2021 06:41:51 +0000 (06:41 +0000)
This PR implements a more aggressive pruning of redundant intervals, namely intervals that are covered by two other intervals, but not a single one. As already discussed in the respective paper, it is not entirely clear whether this is beneficial as removing such an interval may make the "overlap" smaller and thus the generated interval in the lower dimension may become smaller as well. It is thus only enabled via a (new) option.
Experiments show that such redundant intervals are relatively common (878 benchmarks on QF_NRA), the impact of this option is very limited and not strictly beneficial.

src/options/arith_options.toml
src/theory/arith/nl/cad/cdcac.cpp
src/theory/arith/nl/cad/cdcac_utils.cpp
src/theory/arith/nl/cad/cdcac_utils.h

index 6eb67d6d753f0bd77f5511b2ecd4ae3db562dfc1..3c18404556db87bd6c753895536644a7a14c2892 100644 (file)
@@ -499,6 +499,14 @@ name   = "Arithmetic Theory"
   default    = "false"
   help       = "whether to use the cylindrical algebraic decomposition solver for non-linear arithmetic"
 
+[[option]]
+  name       = "nlCadPrune"
+  category   = "regular"
+  long       = "nl-cad-prune"
+  type       = "bool"
+  default    = "false"
+  help       = "whether to prune intervals more agressively"
+
 [[option]]
   name       = "nlCadUseInitial"
   category   = "regular"
index 0c61c088a4e57e5c68f77cb3e755786d55dac2f4..7c0f4b8923559a253bbd2535e83b597667a1c433 100644 (file)
@@ -694,9 +694,30 @@ bool CDCAC::hasRootBelow(const poly::Polynomial& p,
 
 void CDCAC::pruneRedundantIntervals(std::vector<CACInterval>& intervals)
 {
+  cleanIntervals(intervals);
+  if (options().arith.nlCadPrune)
+  {
+    if (Trace.isOn("cdcac"))
+    {
+      auto copy = intervals;
+      removeRedundantIntervals(intervals);
+      if (copy.size() != intervals.size())
+      {
+        Trace("cdcac") << "Before pruning:";
+        for (const auto& i : copy) Trace("cdcac") << " " << i.d_interval;
+        Trace("cdcac") << std::endl;
+        Trace("cdcac") << "After pruning: ";
+        for (const auto& i : intervals) Trace("cdcac") << " " << i.d_interval;
+        Trace("cdcac") << std::endl;
+      }
+    }
+    else
+    {
+      removeRedundantIntervals(intervals);
+    }
+  }
   if (isProofEnabled())
   {
-    cleanIntervals(intervals);
     d_proof->pruneChildren([&intervals](std::size_t id) {
       return std::find_if(intervals.begin(),
                           intervals.end(),
@@ -704,10 +725,6 @@ void CDCAC::pruneRedundantIntervals(std::vector<CACInterval>& intervals)
              != intervals.end();
     });
   }
-  else
-  {
-    cleanIntervals(intervals);
-  }
 }
 
 }  // namespace cad
index b975a0850107cfeb4b9b15dda728194cddb19cb1..429287681306f2acaf0575e2f29d99884abb6614 100644 (file)
@@ -17,6 +17,8 @@
 
 #ifdef CVC5_POLY_IMP
 
+#include <optional>
+
 #include "theory/arith/nl/cad/projections.h"
 
 namespace cvc5 {
@@ -36,11 +38,12 @@ bool operator<(const CACInterval& lhs, const CACInterval& rhs)
   return lhs.d_interval < rhs.d_interval;
 }
 
+namespace {
 /**
  * Induces an ordering on poly intervals that is suitable for redundancy
  * removal as implemented in clean_intervals.
  */
-inline bool compareForCleanup(const Interval& lhs, const Interval& rhs)
+bool compareForCleanup(const Interval& lhs, const Interval& rhs)
 {
   const lp_value_t* ll = &(lhs.get_internal()->a);
   const lp_value_t* lu =
@@ -74,6 +77,9 @@ inline bool compareForCleanup(const Interval& lhs, const Interval& rhs)
   return false;
 }
 
+/**
+ * Check whether lhs covers rhs.
+ */
 bool intervalCovers(const Interval& lhs, const Interval& rhs)
 {
   const lp_value_t* ll = &(lhs.get_internal()->a);
@@ -106,13 +112,17 @@ bool intervalCovers(const Interval& lhs, const Interval& rhs)
   return true;
 }
 
+/**
+ * Check whether two intervals connect, assuming lhs < rhs.
+ * They connect, if their union has no gap.
+ */
 bool intervalConnect(const Interval& lhs, const Interval& rhs)
 {
   Assert(lhs < rhs) << "Can only check for a connection if lhs < rhs.";
-  const lp_value_t* lu = lhs.get_internal()->is_point
-                             ? &(lhs.get_internal()->a)
-                             : &(lhs.get_internal()->b);
-  const lp_value_t* rl = &(rhs.get_internal()->a);
+
+  const lp_value_t* lu = poly::get_upper(lhs).get_internal();
+  const lp_value_t* rl = poly::get_lower(rhs).get_internal();
+
   int c = lp_value_cmp(lu, rl);
   if (c < 0)
   {
@@ -127,17 +137,39 @@ bool intervalConnect(const Interval& lhs, const Interval& rhs)
     return true;
   }
   Assert(c == 0);
-  if (lhs.get_internal()->is_point || rhs.get_internal()->is_point
-      || !lhs.get_internal()->b_open || !rhs.get_internal()->a_open)
+  if (poly::get_upper_open(lhs) && poly::get_lower_open(rhs))
   {
     Trace("libpoly::interval_connect")
         << lhs << " and " << rhs
-        << " touch and the intermediate point is covered." << std::endl;
-    return true;
+        << " touch and the intermediate point is not covered." << std::endl;
+    return false;
   }
-  return false;
+  Trace("libpoly::interval_connect")
+      << lhs << " and " << rhs
+      << " touch and the intermediate point is covered." << std::endl;
+  return true;
 }
 
+/**
+ * Check whether the union of a and b covers rhs.
+ * First check whether a and b connect, and then defer the containment check to
+ * intervalCovers.
+ */
+std::optional<bool> intervalsCover(const Interval& a,
+                                   const Interval& b,
+                                   const Interval& rhs)
+{
+  if (!intervalConnect(a, b)) return {};
+
+  Interval c(poly::get_lower(a),
+             poly::get_lower_open(a),
+             poly::get_upper(b),
+             poly::get_upper_open(b));
+
+  return intervalCovers(c, rhs);
+}
+}  // namespace
+
 void cleanIntervals(std::vector<CACInterval>& intervals)
 {
   // Simplifies removal of redundancies later on.
@@ -150,11 +182,12 @@ void cleanIntervals(std::vector<CACInterval>& intervals)
               return compareForCleanup(lhs.d_interval, rhs.d_interval);
             });
 
-  // Remove intervals that are covered by others.
-  // Implementation roughly follows
-  // https://en.cppreference.com/w/cpp/algorithm/remove Find first interval that
-  // covers the next one.
+  // First remove intervals that are completely covered by a single other
+  // interval. This corresponds to removing "redundancies of the first kind" as
+  // of 4.5.1 The implementation roughly follows
+  // https://en.cppreference.com/w/cpp/algorithm/remove
   std::size_t first = 0;
+  // Find first interval that is covered.
   for (std::size_t n = intervals.size(); first < n - 1; ++first)
   {
     if (intervalCovers(intervals[first].d_interval,
@@ -184,6 +217,48 @@ void cleanIntervals(std::vector<CACInterval>& intervals)
   }
 }
 
+void removeRedundantIntervals(std::vector<CACInterval>& intervals)
+{
+  // mid-1 -> interval below
+  // mid   -> current interval
+  // right -> interval above
+  size_t mid = 1;
+  size_t right = 2;
+  size_t n = intervals.size();
+  while (right < n)
+  {
+    bool found = false;
+    for (size_t r = right; r < n; ++r)
+    {
+      const auto& below = intervals[mid - 1].d_interval;
+      const auto& middle = intervals[mid].d_interval;
+      const auto& above = intervals[r].d_interval;
+      if (intervalsCover(below, above, middle))
+      {
+        found = true;
+        break;
+      }
+    }
+    if (found)
+    {
+      intervals[mid] = std::move(intervals[right]);
+    }
+    else
+    {
+      ++mid;
+      if (mid < right)
+      {
+        intervals[mid] = std::move(intervals[right]);
+      }
+    }
+    ++right;
+  }
+  while (intervals.size() > mid + 1)
+  {
+    intervals.pop_back();
+  }
+}
+
 std::vector<Node> collectConstraints(const std::vector<CACInterval>& intervals)
 {
   std::vector<Node> res;
index 9eb761ae3998947056919560222548f17bc09578..c53e8fbce8d05dffefa1bb84e57232a1b6b2a760 100644 (file)
@@ -67,19 +67,20 @@ bool operator==(const CACInterval& lhs, const CACInterval& rhs);
 /** Compare two intervals. */
 bool operator<(const CACInterval& lhs, const CACInterval& rhs);
 
-/** Check whether lhs covers rhs. */
-bool intervalCovers(const poly::Interval& lhs, const poly::Interval& rhs);
 /**
- * Check whether two intervals connect, assuming lhs < rhs.
- * They connect, if their union has no gap.
+ * Sort intervals according to section 4.4.1.
+ * Also removes fully redundant intervals as in 4.5. 1.; these are intervals
+ * that are fully contained within a single other interval.
  */
-bool intervalConnect(const poly::Interval& lhs, const poly::Interval& rhs);
+void cleanIntervals(std::vector<CACInterval>& intervals);
 
 /**
- * Sort intervals according to section 4.4.1.
- * Also removes fully redundant intervals as in 4.5. 1.
+ * Removes redundant intervals as in 4.5. 2.; these are intervals that are
+ * covered by two other intervals, but not by a single one. Assumes the
+ * intervals to be sorted and cleaned, i.e. that cleanIntervals(intervals) has
+ * been called beforehand.
  */
-void cleanIntervals(std::vector<CACInterval>& intervals);
+void removeRedundantIntervals(std::vector<CACInterval>& intervals);
 
 /**
  * Collect all origins from the list of intervals to construct the origins for a