Use BoundInference in nonlinear extension (#5359)
authorGereon Kremer <gereon.kremer@cs.rwth-aachen.de>
Fri, 30 Oct 2020 08:07:53 +0000 (09:07 +0100)
committerGitHub <noreply@github.com>
Fri, 30 Oct 2020 08:07:53 +0000 (09:07 +0100)
Currently the NonlinearExtensions uses a custom logic to eliminate redundant bounds and perform tightening on bound integer terms. As these replacements are not recorded, incorrect conflicts are being sent to the InferenceManager.
This PR replaces this logic by the BoundInference class and fixes the issues with conflicts by
- allowing BoundInference to collect bounds on arbitrary left hand sides (instead of only variables),
- improving origin tracking in BoundInference by explicitly constructing the new bound constraints,
- adding tightening of integer bounds,
- emitting lemmas instead of conflicts, and finally
- replacing the current logic by using the BoundInference class.

12 files changed:
src/theory/arith/bound_inference.cpp
src/theory/arith/bound_inference.h
src/theory/arith/inference_manager.cpp
src/theory/arith/inference_manager.h
src/theory/arith/nl/cad_solver.cpp
src/theory/arith/nl/icp/contraction_origins.cpp
src/theory/arith/nl/icp/contraction_origins.h
src/theory/arith/nl/icp/icp_solver.cpp
src/theory/arith/nl/icp/icp_solver.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/nonlinear_extension.h
src/theory/arith/nl/poly_conversion.cpp

index 92e71bf14a431ffcbc7fd0975bfaab4080313cec..c8a9527c7a4e94d864e82ab5dfad3fbd9159d742 100644 (file)
@@ -22,69 +22,24 @@ namespace theory {
 namespace arith {
 
 std::ostream& operator<<(std::ostream& os, const Bounds& b) {
-
-  return os << (b.lower_strict ? '(' : '[') << b.lower << " .. " << b.upper
-            << (b.upper_strict ? ')' : ']');
-
-}
-
-void BoundInference::update_lower_bound(const Node& origin,
-                                        const Node& variable,
-                                        const Node& value,
-                                        bool strict)
-{
-  // variable > or >= value because of origin
-  Trace("nl-icp") << "\tNew bound " << variable << (strict ? ">" : ">=")
-                  << value << " due to " << origin << std::endl;
-  Bounds& b = get_or_add(variable);
-  if (b.lower.isNull() || b.lower.getConst<Rational>() < value.getConst<Rational>())
-  {
-    b.lower = value;
-    b.lower_strict = strict;
-    b.lower_origin = origin;
-  }
-  else if (strict && b.lower == value)
-  {
-    b.lower_strict = strict;
-    b.lower_origin = origin;
-  }
-}
-void BoundInference::update_upper_bound(const Node& origin,
-                                        const Node& variable,
-                                        const Node& value,
-                                        bool strict)
-{
-  // variable < or <= value because of origin
-  Trace("nl-icp") << "\tNew bound " << variable << (strict ? "<" : "<=")
-                  << value << " due to " << origin << std::endl;
-  Bounds& b = get_or_add(variable);
-  if (b.upper.isNull() || b.upper.getConst<Rational>() > value.getConst<Rational>())
-  {
-    b.upper = value;
-    b.upper_strict = strict;
-    b.upper_origin = origin;
-  }
-  else if (strict && b.upper == value)
-  {
-    b.upper_strict = strict;
-    b.upper_origin = origin;
-  }
+  return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. "
+            << b.upper_value << (b.upper_strict ? ')' : ']');
 }
 
 void BoundInference::reset() { d_bounds.clear(); }
 
-Bounds& BoundInference::get_or_add(const Node& v)
+Bounds& BoundInference::get_or_add(const Node& lhs)
 {
-  auto it = d_bounds.find(v);
+  auto it = d_bounds.find(lhs);
   if (it == d_bounds.end())
   {
-    it = d_bounds.emplace(v, Bounds()).first;
+    it = d_bounds.emplace(lhs, Bounds()).first;
   }
   return it->second;
 }
-Bounds BoundInference::get(const Node& v) const
+Bounds BoundInference::get(const Node& lhs) const
 {
-  auto it = d_bounds.find(v);
+  auto it = d_bounds.find(lhs);
   if (it == d_bounds.end())
   {
     return Bounds{};
@@ -93,7 +48,7 @@ Bounds BoundInference::get(const Node& v) const
 }
 
 const std::map<Node, Bounds>& BoundInference::get() const { return d_bounds; }
-bool BoundInference::add(const Node& n)
+bool BoundInference::add(const Node& n, bool onlyVariables)
 {
   Node tmp = Rewriter::rewrite(n);
   if (tmp.getKind() == Kind::CONST_BOOLEAN)
@@ -103,46 +58,175 @@ bool BoundInference::add(const Node& n)
   // Parse the node as a comparison
   auto comp = Comparison::parseNormalForm(tmp);
   auto dec = comp.decompose(true);
-  if (std::get<0>(dec).isVariable())
+  if (onlyVariables && !std::get<0>(dec).isVariable())
   {
-    Variable v = std::get<0>(dec).getVariable();
-    Kind relation = std::get<1>(dec);
-    if (relation == Kind::DISTINCT) return false;
-    Constant bound = std::get<2>(dec);
-    // has the form  v  ~relation~  bound
+    return false;
+  }
+
+  Node lhs = std::get<0>(dec).getNode();
+  Kind relation = std::get<1>(dec);
+  if (relation == Kind::DISTINCT) return false;
+  Node bound = std::get<2>(dec).getNode();
+  // has the form  lhs  ~relation~  bound
 
+  if (lhs.getType().isInteger())
+  {
+    Rational br = bound.getConst<Rational>();
+    auto* nm = NodeManager::currentNM();
     switch (relation)
     {
-      case Kind::LEQ:
-        update_upper_bound(n, v.getNode(), bound.getNode(), false);
-        break;
+      case Kind::LEQ: bound = nm->mkConst<Rational>(br.floor()); break;
       case Kind::LT:
-        update_upper_bound(n, v.getNode(), bound.getNode(), true);
-        break;
-      case Kind::EQUAL:
-        update_lower_bound(n, v.getNode(), bound.getNode(), false);
-        update_upper_bound(n, v.getNode(), bound.getNode(), false);
+        bound = nm->mkConst<Rational>((br - 1).ceiling());
+        relation = Kind::LEQ;
         break;
       case Kind::GT:
-        update_lower_bound(n, v.getNode(), bound.getNode(), true);
+        bound = nm->mkConst<Rational>((br + 1).floor());
+        relation = Kind::GEQ;
         break;
-      case Kind::GEQ:
-        update_lower_bound(n, v.getNode(), bound.getNode(), false);
-        break;
-      default: Assert(false);
+      case Kind::GEQ: bound = nm->mkConst<Rational>(br.ceiling()); break;
+      default:;
+    }
+    Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " "
+                       << relation << " " << bound << std::endl;
+  }
+
+  switch (relation)
+  {
+    case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break;
+    case Kind::LT: update_upper_bound(n, lhs, bound, true); break;
+    case Kind::EQUAL:
+      update_lower_bound(n, lhs, bound, false);
+      update_upper_bound(n, lhs, bound, false);
+      break;
+    case Kind::GT: update_lower_bound(n, lhs, bound, true); break;
+    case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break;
+    default: Assert(false);
+  }
+  return true;
+}
+
+void BoundInference::replaceByOrigins(std::vector<Node>& nodes) const
+{
+  std::vector<Node> toAdd;
+  for (auto& n : nodes)
+  {
+    for (const auto& b : d_bounds)
+    {
+      if (n == b.second.lower_bound && n == b.second.upper_bound)
+      {
+        if (n != b.second.lower_origin && n != b.second.upper_origin)
+        {
+          Trace("bound-inf")
+              << "Replace " << n << " by origins " << b.second.lower_origin
+              << " and " << b.second.upper_origin << std::endl;
+          n = b.second.lower_origin;
+          toAdd.emplace_back(b.second.upper_origin);
+        }
+      }
+      else if (n == b.second.lower_bound)
+      {
+        if (n != b.second.lower_origin)
+        {
+          Trace("bound-inf") << "Replace " << n << " by origin "
+                             << b.second.lower_origin << std::endl;
+          n = b.second.lower_origin;
+        }
+      }
+      else if (n == b.second.upper_bound)
+      {
+        if (n != b.second.upper_origin)
+        {
+          Trace("bound-inf") << "Replace " << n << " by origin "
+                             << b.second.upper_origin << std::endl;
+          n = b.second.upper_origin;
+        }
+      }
+    }
+  }
+  nodes.insert(nodes.end(), toAdd.begin(), toAdd.end());
+}
+
+void BoundInference::update_lower_bound(const Node& origin,
+                                        const Node& lhs,
+                                        const Node& value,
+                                        bool strict)
+{
+  // lhs > or >= value because of origin
+  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value
+                     << " due to " << origin << std::endl;
+  Bounds& b = get_or_add(lhs);
+  if (b.lower_value.isNull()
+      || b.lower_value.getConst<Rational>() < value.getConst<Rational>())
+  {
+    auto* nm = NodeManager::currentNM();
+    b.lower_value = value;
+    b.lower_strict = strict;
+
+    b.lower_origin = origin;
+
+    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
+    {
+      b.lower_bound = b.upper_bound =
+          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+    }
+    else
+    {
+      b.lower_bound = Rewriter::rewrite(
+          nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
+    }
+  }
+  else if (strict && b.lower_value == value)
+  {
+    auto* nm = NodeManager::currentNM();
+    b.lower_strict = strict;
+    b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
+    b.lower_origin = origin;
+  }
+}
+void BoundInference::update_upper_bound(const Node& origin,
+                                        const Node& lhs,
+                                        const Node& value,
+                                        bool strict)
+{
+  // lhs < or <= value because of origin
+  Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value
+                     << " due to " << origin << std::endl;
+  Bounds& b = get_or_add(lhs);
+  if (b.upper_value.isNull()
+      || b.upper_value.getConst<Rational>() > value.getConst<Rational>())
+  {
+    auto* nm = NodeManager::currentNM();
+    b.upper_value = value;
+    b.upper_strict = strict;
+    b.upper_origin = origin;
+    if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
+    {
+      b.lower_bound = b.upper_bound =
+          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+    }
+    else
+    {
+      b.upper_bound = Rewriter::rewrite(
+          nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
     }
-    return true;
   }
-  return false;
+  else if (strict && b.upper_value == value)
+  {
+    auto* nm = NodeManager::currentNM();
+    b.upper_strict = strict;
+    b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
+    b.upper_origin = origin;
+  }
 }
 
 std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
 {
   os << "Bounds:" << std::endl;
-  for (const auto& vb : bi.get())
+  for (const auto& b : bi.get())
   {
-    os << "\t" << vb.first << " -> " << vb.second.lower << ".."
-       << vb.second.upper << std::endl;
+    os << "\t" << b.first << " -> " << b.second.lower_value << ".."
+       << b.second.upper_value << std::endl;
   }
   return os;
 }
@@ -153,8 +237,10 @@ std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertio
     bi.add(a);
   }
   std::map<Node, std::pair<Node,Node>> res;
-  for (const auto& vb: bi.get()) {
-    res.emplace(vb.first, std::make_pair(vb.second.lower, vb.second.upper));
+  for (const auto& b : bi.get())
+  {
+    res.emplace(b.first,
+                std::make_pair(b.second.lower_value, b.second.upper_value));
   }
   return res;
 }
index b360ad421b693e52c8e196a12350eff38894d36e..174ba3a0f16c256175df349f6ea47adb73601061 100644 (file)
@@ -27,66 +27,84 @@ namespace arith {
 
   struct Bounds
   {
-    /** The lower bound */
-    Node lower;
+    /** The lower bound value */
+    Node lower_value;
     /** Whether the lower bound is strict or weak */
     bool lower_strict = true;
+    /** The lower bound as constraint */
+    Node lower_bound;
     /** The origin of the lower bound */
     Node lower_origin;
-    /** The upper bound */
-    Node upper;
+    /** The upper bound value */
+    Node upper_value;
     /** Whether the upper bound is strict or weak */
     bool upper_strict = true;
+    /** The upper bound as constraint */
+    Node upper_bound;
     /** The origin of the upper bound */
     Node upper_origin;
   };
 
-/** Print the current variable bounds. */
-std::ostream& operator<<(std::ostream& os, const Bounds& b);
-
-/**
- * A utility class that extracts direct bounds on single variables from theory
- * atoms.
- */
-class BoundInference
-{
-  /** The currently strictest bounds for every variable. */
-  std::map<Node, Bounds> d_bounds;
-
-  /** Updates the lower bound for the given variable */
-  void update_lower_bound(const Node& origin,
-                          const Node& variable,
-                          const Node& value,
-                          bool strict);
-  /** Updates the upper bound for the given variable */
-  void update_upper_bound(const Node& origin,
-                          const Node& variable,
-                          const Node& value,
-                          bool strict);
-
- public:
-  void reset();
+  /** Print the current bounds. */
+  std::ostream& operator<<(std::ostream& os, const Bounds& b);
 
   /**
-   * Get the current interval for v. Creates a new (full) interval if
-   * necessary.
-   */
-  Bounds& get_or_add(const Node& v);
-  /**
-   * Get the current interval for v. Returns a full interval if no interval was
-   * derived yet.
+   * A utility class that extracts direct bounds on arithmetic terms from theory
+   * atoms.
    */
-  Bounds get(const Node& v) const;
-
-  /** Return the current variable bounds as an interval assignment. */
-  const std::map<Node, Bounds>& get() const;
-
-  /**
-   * Add a new theory atom. Return true if the theory atom induces a new
-   * variable bound.
-   */
-  bool add(const Node& n);
-};
+  class BoundInference
+  {
+   public:
+    void reset();
+
+    /**
+     * Get the current interval for lhs. Creates a new (full) interval if
+     * necessary.
+     */
+    Bounds& get_or_add(const Node& lhs);
+    /**
+     * Get the current interval for lhs. Returns a full interval if no interval
+     * was derived yet.
+     */
+    Bounds get(const Node& lhs) const;
+
+    /** Return the current term bounds as an interval assignment. */
+    const std::map<Node, Bounds>& get() const;
+
+    /**
+     * Add a new theory atom. Return true if the theory atom induces a new
+     * term bound.
+     * If onlyVariables is true, the left hand side needs to be a single
+     * variable to induce a bound.
+     */
+    bool add(const Node& n, bool onlyVariables = true);
+
+    /**
+     * Post-processes a set of nodes and replaces bounds by their origins.
+     * This utility sometimes creates new bounds, either due to tightening of
+     * integer terms or because an equality was derived from two weak
+     * inequalities. While the origins of these new bounds are recorded in
+     * lower_origin and upper_origin, this method can be used to conveniently
+     * replace these new nodes by their origins.
+     * This can be used, for example, when constructing conflicts.
+     */
+    void replaceByOrigins(std::vector<Node>& nodes) const;
+
+   private:
+    /** The currently strictest bounds for every lhs. */
+    std::map<Node, Bounds> d_bounds;
+
+    /** Updates the lower bound for the given lhs */
+    void update_lower_bound(const Node& origin,
+                            const Node& lhs,
+                            const Node& value,
+                            bool strict);
+    /** Updates the upper bound for the given lhs */
+    void update_upper_bound(const Node& origin,
+                            const Node& lhs,
+                            const Node& value,
+                            bool strict);
+  };
 
 /** Print the current variable bounds. */
 std::ostream& operator<<(std::ostream& os, const BoundInference& bi);
index 656b5ed0dde8c9d46011015808a43536075c3b9e..43359c4602519d6474aa752355f9b74d59c3ca0b 100644 (file)
@@ -91,13 +91,6 @@ void InferenceManager::clearWaitingLemmas()
   d_waitingLem.clear();
 }
 
-void InferenceManager::addConflict(const Node& conf, InferenceId inftype)
-{
-  Trace("arith::infman") << "Adding conflict: " << inftype << " " << conf
-                         << std::endl;
-  conflict(conf);
-}
-
 bool InferenceManager::hasUsed() const
 {
   return hasSent() || hasPending();
index 9228add196603e110acbe6b9936d35f7f5ec3b73..f2784ed89ace9990510c9244b1f6d44892e2aa1a 100644 (file)
@@ -83,9 +83,6 @@ class InferenceManager : public InferenceManagerBuffered
    */
   void clearWaitingLemmas();
 
-  /** Add a conflict to the this inference manager. */
-  void addConflict(const Node& conf, InferenceId inftype);
-
   /**
    * Checks whether we have made any progress, that is whether a conflict, lemma
    * or fact was added or whether a lemma or fact is pending.
index d12a861ac6ff86c93ae8374c31385d9973d4c42e..831530995397453706344f1510eac4315d7c0786 100644 (file)
@@ -84,8 +84,12 @@ void CadSolver::checkFull()
     Trace("nl-cad") << "Collected MIS: " << mis << std::endl;
     Assert(!mis.empty()) << "Infeasible subset can not be empty";
     Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl;
-    d_im.addConflict(NodeManager::currentNM()->mkAnd(mis),
-                     InferenceId::NL_CAD_CONFLICT);
+    for (auto& n : mis)
+    {
+      n = n.negate();
+    }
+    d_im.addPendingArithLemma(NodeManager::currentNM()->mkOr(mis),
+                              InferenceId::NL_CAD_CONFLICT);
   }
 #else
   Warning() << "Tried to use CadSolver but libpoly is not available. Compile "
index 1e8f0769a91136c6cfbd54596f0b0416c4ffc5da..779c000b76e0531cee4919e935606ec69670fb15 100644 (file)
@@ -67,7 +67,8 @@ void ContractionOriginManager::add(const Node& targetVariable,
   d_currentOrigins[targetVariable] = d_allocations.back().get();
 }
 
-Node ContractionOriginManager::getOrigins(const Node& variable) const
+std::vector<Node> ContractionOriginManager::getOrigins(
+    const Node& variable) const
 {
   Trace("nl-icp") << "Obtaining origins for " << variable << std::endl;
   std::set<Node> origins;
@@ -75,12 +76,7 @@ Node ContractionOriginManager::getOrigins(const Node& variable) const
       << "Using variable as origin that is unknown yet.";
   getOrigins(d_currentOrigins.at(variable), origins);
   Assert(!origins.empty()) << "There should be at least one origin";
-  if (origins.size() == 1)
-  {
-    return *origins.begin();
-  }
-  return NodeManager::currentNM()->mkNode(
-      Kind::AND, std::vector<Node>(origins.begin(), origins.end()));
+  return std::vector<Node>(origins.begin(), origins.end());
 }
 
 bool ContractionOriginManager::isInOrigins(const Node& variable,
index d8e56759dbe6dcbb968b16e3479237d82c557581..885fc740a48c2ffe08c0079b523adc613a5c62f1 100644 (file)
@@ -80,7 +80,7 @@ class ContractionOriginManager
   /**
    * Collect all theory atoms from the origins of the given variable.
    */
-  Node getOrigins(const Node& variable) const;
+  std::vector<Node> getOrigins(const Node& variable) const;
 
   /** Check whether a node c is among the origins of a variable. */
   bool isInOrigins(const Node& variable, const Node& c) const;
index 4ec33c3609bb942cf99c86e6c2369382776b1a37..b4cb54216a79147dc926f1c54cef2170adccb273 100644 (file)
@@ -107,7 +107,7 @@ std::vector<Candidate> ICPSolver::constructCandidates(const Node& n)
     if (isolated == 1)
     {
       poly::Variable lhs = d_mapper(v);
-      poly::SignCondition rel;
+      poly::SignCondition rel = poly::SignCondition::EQ;
       switch (k)
       {
         case Kind::LT: rel = poly::SignCondition::LT; break;
@@ -133,7 +133,7 @@ std::vector<Candidate> ICPSolver::constructCandidates(const Node& n)
     else if (isolated == -1)
     {
       poly::Variable lhs = d_mapper(v);
-      poly::SignCondition rel;
+      poly::SignCondition rel = poly::SignCondition::EQ;
       switch (k)
       {
         case Kind::LT: rel = poly::SignCondition::GT; break;
@@ -210,7 +210,7 @@ PropagationResult ICPSolver::doPropagationRound()
     Trace("nl-icp") << "ICP budget exceeded" << std::endl;
     return PropagationResult::NOT_CHANGED;
   }
-  d_state.d_conflict = Node();
+  d_state.d_conflict.clear();
   Trace("nl-icp") << "Starting propagation with "
                   << IAWrapper{d_state.d_assignment, d_mapper} << std::endl;
   Trace("nl-icp") << "Current budget: " << d_budget << std::endl;
@@ -267,7 +267,7 @@ std::vector<Node> ICPSolver::generateLemmas() const
       Node c = nm->mkNode(rel, v, value_to_node(get_lower(i), v));
       if (!d_state.d_origins.isInOrigins(v, c))
       {
-        Node premise = d_state.d_origins.getOrigins(v);
+        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
         Trace("nl-icp") << premise << " => " << c << std::endl;
         Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
         if (lemma.isConst())
@@ -287,7 +287,7 @@ std::vector<Node> ICPSolver::generateLemmas() const
       Node c = nm->mkNode(rel, v, value_to_node(get_upper(i), v));
       if (!d_state.d_origins.isInOrigins(v, c))
       {
-        Node premise = d_state.d_origins.getOrigins(v);
+        Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v));
         Trace("nl-icp") << premise << " => " << c << std::endl;
         Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c));
         if (lemma.isConst())
@@ -343,7 +343,13 @@ void ICPSolver::check()
         Trace("nl-icp") << "Found a conflict: " << d_state.d_conflict
                         << std::endl;
 
-        d_im.addConflict(d_state.d_conflict, InferenceId::NL_ICP_CONFLICT);
+        std::vector<Node> mis;
+        for (const auto& n : d_state.d_conflict)
+        {
+          mis.emplace_back(n.negate());
+        }
+        d_im.addPendingArithLemma(NodeManager::currentNM()->mkOr(mis),
+                                  InferenceId::NL_ICP_CONFLICT);
         did_progress = true;
         progress = false;
         break;
index ca2aef10a5cc56f16621feaecb8544c599b501f2..32861c6410dad24356b2cb4d60825938c862de26 100644 (file)
@@ -67,8 +67,8 @@ class ICPSolver
     poly::IntervalAssignment d_assignment;
     /** The origins for the current assignment */
     ContractionOriginManager d_origins;
-    /** The conflict, if any way found. Initially the null node */
-    Node d_conflict;
+    /** The conflict, if any way found. Initially empty */
+    std::vector<Node> d_conflict;
 
     /** Initialized the variable bounds with a variable mapper */
     ICPState(VariableMapper& vm) {}
@@ -80,7 +80,7 @@ class ICPSolver
       d_candidates.clear();
       d_assignment.clear();
       d_origins = ContractionOriginManager();
-      d_conflict = Node();
+      d_conflict.clear();
     }
   };
 
index 76f37213afb28933679ba16133f1c0782c2c76d7..fdab6d7b73d1aaecd9bc624d7186b01325b3406c 100644 (file)
@@ -21,6 +21,7 @@
 #include "options/theory_options.h"
 #include "theory/arith/arith_state.h"
 #include "theory/arith/arith_utilities.h"
+#include "theory/arith/bound_inference.h"
 #include "theory/arith/theory_arith.h"
 #include "theory/ext_theory.h"
 #include "theory/theory_model.h"
@@ -179,16 +180,15 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
   }
   Valuation v = d_containing.getValuation();
   NodeManager* nm = NodeManager::currentNM();
-  // get the assertions
-  std::map<Node, Rational> init_bounds[2];
-  std::map<Node, Node> init_bounds_lit[2];
-  unsigned nassertions = 0;
+
+  BoundInference bounds;
+
   std::unordered_set<Node, NodeHashFunction> init_assertions;
+
   for (Theory::assertions_iterator it = d_containing.facts_begin();
        it != d_containing.facts_end();
        ++it)
   {
-    nassertions++;
     const Assertion& assertion = *it;
     Trace("nl-ext") << "Loaded " << assertion.d_assertion << " from theory"
                     << std::endl;
@@ -198,97 +198,23 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
       // not relevant, skip
       continue;
     }
-    init_assertions.insert(lit);
-    // check for concrete bounds
-    bool pol = lit.getKind() != NOT;
-    Node atom_orig = lit.getKind() == NOT ? lit[0] : lit;
-
-    std::vector<Node> atoms;
-    if (atom_orig.getKind() == EQUAL)
-    {
-      if (pol)
-      {
-        // t = s  is ( t >= s ^ t <= s )
-        for (unsigned i = 0; i < 2; i++)
-        {
-          Node atom_new = nm->mkNode(GEQ, atom_orig[i], atom_orig[1 - i]);
-          atom_new = Rewriter::rewrite(atom_new);
-          atoms.push_back(atom_new);
-        }
-      }
-    }
-    else
+    if (bounds.add(lit, false))
     {
-      atoms.push_back(atom_orig);
+      continue;
     }
+    init_assertions.insert(lit);
+  }
 
-    for (const Node& atom : atoms)
+  for (const auto& vb : bounds.get())
+  {
+    const Bounds& b = vb.second;
+    if (!b.lower_bound.isNull())
     {
-      // non-strict bounds only
-      if (atom.getKind() == GEQ || (!pol && atom.getKind() == GT))
-      {
-        Node p = atom[0];
-        Assert(atom[1].isConst());
-        Rational bound = atom[1].getConst<Rational>();
-        if (!pol)
-        {
-          if (atom[0].getType().isInteger())
-          {
-            // ~( p >= c ) ---> ( p <= c-1 )
-            bound = bound - Rational(1);
-          }
-        }
-        unsigned bindex = pol ? 0 : 1;
-        bool setBound = true;
-        std::map<Node, Rational>::iterator itb = init_bounds[bindex].find(p);
-        if (itb != init_bounds[bindex].end())
-        {
-          if (itb->second == bound)
-          {
-            setBound = atom_orig.getKind() == EQUAL;
-          }
-          else
-          {
-            setBound = pol ? itb->second < bound : itb->second > bound;
-          }
-          if (setBound)
-          {
-            // the bound is subsumed
-            init_assertions.erase(init_bounds_lit[bindex][p]);
-          }
-        }
-        if (setBound)
-        {
-          Trace("nl-ext-init") << (pol ? "Lower" : "Upper") << " bound for "
-                               << p << " : " << bound << std::endl;
-          init_bounds[bindex][p] = bound;
-          init_bounds_lit[bindex][p] = lit;
-        }
-      }
+      init_assertions.insert(b.lower_bound);
     }
-  }
-  // for each bound that is the same, ensure we've inferred the equality
-  for (std::pair<const Node, Rational>& ib : init_bounds[0])
-  {
-    Node p = ib.first;
-    Node lit1 = init_bounds_lit[0][p];
-    if (lit1.getKind() != EQUAL)
+    if (!b.upper_bound.isNull())
     {
-      std::map<Node, Rational>::iterator itb = init_bounds[1].find(p);
-      if (itb != init_bounds[1].end())
-      {
-        if (ib.second == itb->second)
-        {
-          Node eq = p.eqNode(nm->mkConst(ib.second));
-          eq = Rewriter::rewrite(eq);
-          Node lit2 = init_bounds_lit[1][p];
-          Assert(lit2.getKind() != EQUAL);
-          // use the equality instead, thus these are redundant
-          init_assertions.erase(lit1);
-          init_assertions.erase(lit2);
-          init_assertions.insert(eq);
-        }
-      }
+      init_assertions.insert(b.upper_bound);
     }
   }
 
@@ -301,6 +227,7 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
     auto iait = init_assertions.find(lit);
     if (iait != init_assertions.end())
     {
+      Trace("nl-ext") << "Adding " << lit << std::endl;
       assertions.push_back(lit);
       init_assertions.erase(iait);
     }
@@ -309,10 +236,12 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
   // function by the code above.
   for (const Node& a : init_assertions)
   {
+    Trace("nl-ext") << "Adding " << a << std::endl;
     assertions.push_back(a);
   }
-  Trace("nl-ext") << "...keep " << assertions.size() << " / " << nassertions
-                  << " assertions." << std::endl;
+  Trace("nl-ext") << "...keep " << assertions.size() << " / "
+                  << d_containing.numAssertions() << " assertions."
+                  << std::endl;
 }
 
 std::vector<Node> NonlinearExtension::checkModelEval(
index 2f4586d78eafb1cc6b7cd5bb126b0e4a7b5b601c..bd30422313a5d89c75465160ccffb409ee3427e9 100644 (file)
@@ -249,6 +249,7 @@ class NonlinearExtension
    * and for establishing when we are able to answer "SAT".
    */
   NlModel d_model;
+
   /** The transcendental extension object
    *
    * This is the subsolver responsible for running the procedure for
index a76a781c49f81555bf4ddc8b4f9a6b03021e0097..0e4e21b76705eac3ceaf45be42c834a4a86a35a5 100644 (file)
@@ -785,12 +785,12 @@ poly::IntervalAssignment getBounds(VariableMapper& vm, const BoundInference& bi)
   for (const auto& vb : bi.get())
   {
     poly::Variable v = vm(vb.first);
-    poly::Value l = vb.second.lower.isNull()
+    poly::Value l = vb.second.lower_value.isNull()
                         ? poly::Value::minus_infty()
-                        : node_to_value(vb.second.lower, vb.first);
-    poly::Value u = vb.second.upper.isNull()
+                        : node_to_value(vb.second.lower_value, vb.first);
+    poly::Value u = vb.second.upper_value.isNull()
                         ? poly::Value::plus_infty()
-                        : node_to_value(vb.second.upper, vb.first);
+                        : node_to_value(vb.second.upper_value, vb.first);
     poly::Interval i(l, vb.second.lower_strict, u, vb.second.upper_strict);
     res.set(v, i);
   }