From 42fe4ae0e866d78dd7743214eb1b1ccb92900c5a Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Fri, 30 Oct 2020 09:07:53 +0100 Subject: [PATCH] Use BoundInference in nonlinear extension (#5359) Currently the NonlinearExtensions uses a custom logic to eliminate redundant bounds and perform tightening on bound integer terms. As these replacements are not recorded, incorrect conflicts are being sent to the InferenceManager. This PR replaces this logic by the BoundInference class and fixes the issues with conflicts by - allowing BoundInference to collect bounds on arbitrary left hand sides (instead of only variables), - improving origin tracking in BoundInference by explicitly constructing the new bound constraints, - adding tightening of integer bounds, - emitting lemmas instead of conflicts, and finally - replacing the current logic by using the BoundInference class. --- src/theory/arith/bound_inference.cpp | 244 ++++++++++++------ src/theory/arith/bound_inference.h | 112 ++++---- src/theory/arith/inference_manager.cpp | 7 - src/theory/arith/inference_manager.h | 3 - src/theory/arith/nl/cad_solver.cpp | 8 +- .../arith/nl/icp/contraction_origins.cpp | 10 +- src/theory/arith/nl/icp/contraction_origins.h | 2 +- src/theory/arith/nl/icp/icp_solver.cpp | 18 +- src/theory/arith/nl/icp/icp_solver.h | 6 +- src/theory/arith/nl/nonlinear_extension.cpp | 113 ++------ src/theory/arith/nl/nonlinear_extension.h | 1 + src/theory/arith/nl/poly_conversion.cpp | 8 +- 12 files changed, 281 insertions(+), 251 deletions(-) diff --git a/src/theory/arith/bound_inference.cpp b/src/theory/arith/bound_inference.cpp index 92e71bf14..c8a9527c7 100644 --- a/src/theory/arith/bound_inference.cpp +++ b/src/theory/arith/bound_inference.cpp @@ -22,69 +22,24 @@ namespace theory { namespace arith { std::ostream& operator<<(std::ostream& os, const Bounds& b) { - - return os << (b.lower_strict ? '(' : '[') << b.lower << " .. " << b.upper - << (b.upper_strict ? ')' : ']'); - -} - -void BoundInference::update_lower_bound(const Node& origin, - const Node& variable, - const Node& value, - bool strict) -{ - // variable > or >= value because of origin - Trace("nl-icp") << "\tNew bound " << variable << (strict ? ">" : ">=") - << value << " due to " << origin << std::endl; - Bounds& b = get_or_add(variable); - if (b.lower.isNull() || b.lower.getConst() < value.getConst()) - { - b.lower = value; - b.lower_strict = strict; - b.lower_origin = origin; - } - else if (strict && b.lower == value) - { - b.lower_strict = strict; - b.lower_origin = origin; - } -} -void BoundInference::update_upper_bound(const Node& origin, - const Node& variable, - const Node& value, - bool strict) -{ - // variable < or <= value because of origin - Trace("nl-icp") << "\tNew bound " << variable << (strict ? "<" : "<=") - << value << " due to " << origin << std::endl; - Bounds& b = get_or_add(variable); - if (b.upper.isNull() || b.upper.getConst() > value.getConst()) - { - b.upper = value; - b.upper_strict = strict; - b.upper_origin = origin; - } - else if (strict && b.upper == value) - { - b.upper_strict = strict; - b.upper_origin = origin; - } + return os << (b.lower_strict ? '(' : '[') << b.lower_value << " .. " + << b.upper_value << (b.upper_strict ? ')' : ']'); } void BoundInference::reset() { d_bounds.clear(); } -Bounds& BoundInference::get_or_add(const Node& v) +Bounds& BoundInference::get_or_add(const Node& lhs) { - auto it = d_bounds.find(v); + auto it = d_bounds.find(lhs); if (it == d_bounds.end()) { - it = d_bounds.emplace(v, Bounds()).first; + it = d_bounds.emplace(lhs, Bounds()).first; } return it->second; } -Bounds BoundInference::get(const Node& v) const +Bounds BoundInference::get(const Node& lhs) const { - auto it = d_bounds.find(v); + auto it = d_bounds.find(lhs); if (it == d_bounds.end()) { return Bounds{}; @@ -93,7 +48,7 @@ Bounds BoundInference::get(const Node& v) const } const std::map& BoundInference::get() const { return d_bounds; } -bool BoundInference::add(const Node& n) +bool BoundInference::add(const Node& n, bool onlyVariables) { Node tmp = Rewriter::rewrite(n); if (tmp.getKind() == Kind::CONST_BOOLEAN) @@ -103,46 +58,175 @@ bool BoundInference::add(const Node& n) // Parse the node as a comparison auto comp = Comparison::parseNormalForm(tmp); auto dec = comp.decompose(true); - if (std::get<0>(dec).isVariable()) + if (onlyVariables && !std::get<0>(dec).isVariable()) { - Variable v = std::get<0>(dec).getVariable(); - Kind relation = std::get<1>(dec); - if (relation == Kind::DISTINCT) return false; - Constant bound = std::get<2>(dec); - // has the form v ~relation~ bound + return false; + } + + Node lhs = std::get<0>(dec).getNode(); + Kind relation = std::get<1>(dec); + if (relation == Kind::DISTINCT) return false; + Node bound = std::get<2>(dec).getNode(); + // has the form lhs ~relation~ bound + if (lhs.getType().isInteger()) + { + Rational br = bound.getConst(); + auto* nm = NodeManager::currentNM(); switch (relation) { - case Kind::LEQ: - update_upper_bound(n, v.getNode(), bound.getNode(), false); - break; + case Kind::LEQ: bound = nm->mkConst(br.floor()); break; case Kind::LT: - update_upper_bound(n, v.getNode(), bound.getNode(), true); - break; - case Kind::EQUAL: - update_lower_bound(n, v.getNode(), bound.getNode(), false); - update_upper_bound(n, v.getNode(), bound.getNode(), false); + bound = nm->mkConst((br - 1).ceiling()); + relation = Kind::LEQ; break; case Kind::GT: - update_lower_bound(n, v.getNode(), bound.getNode(), true); + bound = nm->mkConst((br + 1).floor()); + relation = Kind::GEQ; break; - case Kind::GEQ: - update_lower_bound(n, v.getNode(), bound.getNode(), false); - break; - default: Assert(false); + case Kind::GEQ: bound = nm->mkConst(br.ceiling()); break; + default:; + } + Trace("bound-inf") << "Strengthened " << n << " to " << lhs << " " + << relation << " " << bound << std::endl; + } + + switch (relation) + { + case Kind::LEQ: update_upper_bound(n, lhs, bound, false); break; + case Kind::LT: update_upper_bound(n, lhs, bound, true); break; + case Kind::EQUAL: + update_lower_bound(n, lhs, bound, false); + update_upper_bound(n, lhs, bound, false); + break; + case Kind::GT: update_lower_bound(n, lhs, bound, true); break; + case Kind::GEQ: update_lower_bound(n, lhs, bound, false); break; + default: Assert(false); + } + return true; +} + +void BoundInference::replaceByOrigins(std::vector& nodes) const +{ + std::vector toAdd; + for (auto& n : nodes) + { + for (const auto& b : d_bounds) + { + if (n == b.second.lower_bound && n == b.second.upper_bound) + { + if (n != b.second.lower_origin && n != b.second.upper_origin) + { + Trace("bound-inf") + << "Replace " << n << " by origins " << b.second.lower_origin + << " and " << b.second.upper_origin << std::endl; + n = b.second.lower_origin; + toAdd.emplace_back(b.second.upper_origin); + } + } + else if (n == b.second.lower_bound) + { + if (n != b.second.lower_origin) + { + Trace("bound-inf") << "Replace " << n << " by origin " + << b.second.lower_origin << std::endl; + n = b.second.lower_origin; + } + } + else if (n == b.second.upper_bound) + { + if (n != b.second.upper_origin) + { + Trace("bound-inf") << "Replace " << n << " by origin " + << b.second.upper_origin << std::endl; + n = b.second.upper_origin; + } + } + } + } + nodes.insert(nodes.end(), toAdd.begin(), toAdd.end()); +} + +void BoundInference::update_lower_bound(const Node& origin, + const Node& lhs, + const Node& value, + bool strict) +{ + // lhs > or >= value because of origin + Trace("bound-inf") << "\tNew bound " << lhs << (strict ? ">" : ">=") << value + << " due to " << origin << std::endl; + Bounds& b = get_or_add(lhs); + if (b.lower_value.isNull() + || b.lower_value.getConst() < value.getConst()) + { + auto* nm = NodeManager::currentNM(); + b.lower_value = value; + b.lower_strict = strict; + + b.lower_origin = origin; + + if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) + { + b.lower_bound = b.upper_bound = + Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); + } + else + { + b.lower_bound = Rewriter::rewrite( + nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value)); + } + } + else if (strict && b.lower_value == value) + { + auto* nm = NodeManager::currentNM(); + b.lower_strict = strict; + b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value)); + b.lower_origin = origin; + } +} +void BoundInference::update_upper_bound(const Node& origin, + const Node& lhs, + const Node& value, + bool strict) +{ + // lhs < or <= value because of origin + Trace("bound-inf") << "\tNew bound " << lhs << (strict ? "<" : "<=") << value + << " due to " << origin << std::endl; + Bounds& b = get_or_add(lhs); + if (b.upper_value.isNull() + || b.upper_value.getConst() > value.getConst()) + { + auto* nm = NodeManager::currentNM(); + b.upper_value = value; + b.upper_strict = strict; + b.upper_origin = origin; + if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) + { + b.lower_bound = b.upper_bound = + Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); + } + else + { + b.upper_bound = Rewriter::rewrite( + nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value)); } - return true; } - return false; + else if (strict && b.upper_value == value) + { + auto* nm = NodeManager::currentNM(); + b.upper_strict = strict; + b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value)); + b.upper_origin = origin; + } } std::ostream& operator<<(std::ostream& os, const BoundInference& bi) { os << "Bounds:" << std::endl; - for (const auto& vb : bi.get()) + for (const auto& b : bi.get()) { - os << "\t" << vb.first << " -> " << vb.second.lower << ".." - << vb.second.upper << std::endl; + os << "\t" << b.first << " -> " << b.second.lower_value << ".." + << b.second.upper_value << std::endl; } return os; } @@ -153,8 +237,10 @@ std::map> getBounds(const std::vector& assertio bi.add(a); } std::map> res; - for (const auto& vb: bi.get()) { - res.emplace(vb.first, std::make_pair(vb.second.lower, vb.second.upper)); + for (const auto& b : bi.get()) + { + res.emplace(b.first, + std::make_pair(b.second.lower_value, b.second.upper_value)); } return res; } diff --git a/src/theory/arith/bound_inference.h b/src/theory/arith/bound_inference.h index b360ad421..174ba3a0f 100644 --- a/src/theory/arith/bound_inference.h +++ b/src/theory/arith/bound_inference.h @@ -27,66 +27,84 @@ namespace arith { struct Bounds { - /** The lower bound */ - Node lower; + /** The lower bound value */ + Node lower_value; /** Whether the lower bound is strict or weak */ bool lower_strict = true; + /** The lower bound as constraint */ + Node lower_bound; /** The origin of the lower bound */ Node lower_origin; - /** The upper bound */ - Node upper; + /** The upper bound value */ + Node upper_value; /** Whether the upper bound is strict or weak */ bool upper_strict = true; + /** The upper bound as constraint */ + Node upper_bound; /** The origin of the upper bound */ Node upper_origin; }; -/** Print the current variable bounds. */ -std::ostream& operator<<(std::ostream& os, const Bounds& b); - -/** - * A utility class that extracts direct bounds on single variables from theory - * atoms. - */ -class BoundInference -{ - /** The currently strictest bounds for every variable. */ - std::map d_bounds; - - /** Updates the lower bound for the given variable */ - void update_lower_bound(const Node& origin, - const Node& variable, - const Node& value, - bool strict); - /** Updates the upper bound for the given variable */ - void update_upper_bound(const Node& origin, - const Node& variable, - const Node& value, - bool strict); - - public: - void reset(); + /** Print the current bounds. */ + std::ostream& operator<<(std::ostream& os, const Bounds& b); /** - * Get the current interval for v. Creates a new (full) interval if - * necessary. - */ - Bounds& get_or_add(const Node& v); - /** - * Get the current interval for v. Returns a full interval if no interval was - * derived yet. + * A utility class that extracts direct bounds on arithmetic terms from theory + * atoms. */ - Bounds get(const Node& v) const; - - /** Return the current variable bounds as an interval assignment. */ - const std::map& get() const; - - /** - * Add a new theory atom. Return true if the theory atom induces a new - * variable bound. - */ - bool add(const Node& n); -}; + class BoundInference + { + public: + void reset(); + + /** + * Get the current interval for lhs. Creates a new (full) interval if + * necessary. + */ + Bounds& get_or_add(const Node& lhs); + /** + * Get the current interval for lhs. Returns a full interval if no interval + * was derived yet. + */ + Bounds get(const Node& lhs) const; + + /** Return the current term bounds as an interval assignment. */ + const std::map& get() const; + + /** + * Add a new theory atom. Return true if the theory atom induces a new + * term bound. + * If onlyVariables is true, the left hand side needs to be a single + * variable to induce a bound. + */ + bool add(const Node& n, bool onlyVariables = true); + + /** + * Post-processes a set of nodes and replaces bounds by their origins. + * This utility sometimes creates new bounds, either due to tightening of + * integer terms or because an equality was derived from two weak + * inequalities. While the origins of these new bounds are recorded in + * lower_origin and upper_origin, this method can be used to conveniently + * replace these new nodes by their origins. + * This can be used, for example, when constructing conflicts. + */ + void replaceByOrigins(std::vector& nodes) const; + + private: + /** The currently strictest bounds for every lhs. */ + std::map d_bounds; + + /** Updates the lower bound for the given lhs */ + void update_lower_bound(const Node& origin, + const Node& lhs, + const Node& value, + bool strict); + /** Updates the upper bound for the given lhs */ + void update_upper_bound(const Node& origin, + const Node& lhs, + const Node& value, + bool strict); + }; /** Print the current variable bounds. */ std::ostream& operator<<(std::ostream& os, const BoundInference& bi); diff --git a/src/theory/arith/inference_manager.cpp b/src/theory/arith/inference_manager.cpp index 656b5ed0d..43359c460 100644 --- a/src/theory/arith/inference_manager.cpp +++ b/src/theory/arith/inference_manager.cpp @@ -91,13 +91,6 @@ void InferenceManager::clearWaitingLemmas() d_waitingLem.clear(); } -void InferenceManager::addConflict(const Node& conf, InferenceId inftype) -{ - Trace("arith::infman") << "Adding conflict: " << inftype << " " << conf - << std::endl; - conflict(conf); -} - bool InferenceManager::hasUsed() const { return hasSent() || hasPending(); diff --git a/src/theory/arith/inference_manager.h b/src/theory/arith/inference_manager.h index 9228add19..f2784ed89 100644 --- a/src/theory/arith/inference_manager.h +++ b/src/theory/arith/inference_manager.h @@ -83,9 +83,6 @@ class InferenceManager : public InferenceManagerBuffered */ void clearWaitingLemmas(); - /** Add a conflict to the this inference manager. */ - void addConflict(const Node& conf, InferenceId inftype); - /** * Checks whether we have made any progress, that is whether a conflict, lemma * or fact was added or whether a lemma or fact is pending. diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp index d12a861ac..831530995 100644 --- a/src/theory/arith/nl/cad_solver.cpp +++ b/src/theory/arith/nl/cad_solver.cpp @@ -84,8 +84,12 @@ void CadSolver::checkFull() Trace("nl-cad") << "Collected MIS: " << mis << std::endl; Assert(!mis.empty()) << "Infeasible subset can not be empty"; Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl; - d_im.addConflict(NodeManager::currentNM()->mkAnd(mis), - InferenceId::NL_CAD_CONFLICT); + for (auto& n : mis) + { + n = n.negate(); + } + d_im.addPendingArithLemma(NodeManager::currentNM()->mkOr(mis), + InferenceId::NL_CAD_CONFLICT); } #else Warning() << "Tried to use CadSolver but libpoly is not available. Compile " diff --git a/src/theory/arith/nl/icp/contraction_origins.cpp b/src/theory/arith/nl/icp/contraction_origins.cpp index 1e8f0769a..779c000b7 100644 --- a/src/theory/arith/nl/icp/contraction_origins.cpp +++ b/src/theory/arith/nl/icp/contraction_origins.cpp @@ -67,7 +67,8 @@ void ContractionOriginManager::add(const Node& targetVariable, d_currentOrigins[targetVariable] = d_allocations.back().get(); } -Node ContractionOriginManager::getOrigins(const Node& variable) const +std::vector ContractionOriginManager::getOrigins( + const Node& variable) const { Trace("nl-icp") << "Obtaining origins for " << variable << std::endl; std::set origins; @@ -75,12 +76,7 @@ Node ContractionOriginManager::getOrigins(const Node& variable) const << "Using variable as origin that is unknown yet."; getOrigins(d_currentOrigins.at(variable), origins); Assert(!origins.empty()) << "There should be at least one origin"; - if (origins.size() == 1) - { - return *origins.begin(); - } - return NodeManager::currentNM()->mkNode( - Kind::AND, std::vector(origins.begin(), origins.end())); + return std::vector(origins.begin(), origins.end()); } bool ContractionOriginManager::isInOrigins(const Node& variable, diff --git a/src/theory/arith/nl/icp/contraction_origins.h b/src/theory/arith/nl/icp/contraction_origins.h index d8e56759d..885fc740a 100644 --- a/src/theory/arith/nl/icp/contraction_origins.h +++ b/src/theory/arith/nl/icp/contraction_origins.h @@ -80,7 +80,7 @@ class ContractionOriginManager /** * Collect all theory atoms from the origins of the given variable. */ - Node getOrigins(const Node& variable) const; + std::vector getOrigins(const Node& variable) const; /** Check whether a node c is among the origins of a variable. */ bool isInOrigins(const Node& variable, const Node& c) const; diff --git a/src/theory/arith/nl/icp/icp_solver.cpp b/src/theory/arith/nl/icp/icp_solver.cpp index 4ec33c360..b4cb54216 100644 --- a/src/theory/arith/nl/icp/icp_solver.cpp +++ b/src/theory/arith/nl/icp/icp_solver.cpp @@ -107,7 +107,7 @@ std::vector ICPSolver::constructCandidates(const Node& n) if (isolated == 1) { poly::Variable lhs = d_mapper(v); - poly::SignCondition rel; + poly::SignCondition rel = poly::SignCondition::EQ; switch (k) { case Kind::LT: rel = poly::SignCondition::LT; break; @@ -133,7 +133,7 @@ std::vector ICPSolver::constructCandidates(const Node& n) else if (isolated == -1) { poly::Variable lhs = d_mapper(v); - poly::SignCondition rel; + poly::SignCondition rel = poly::SignCondition::EQ; switch (k) { case Kind::LT: rel = poly::SignCondition::GT; break; @@ -210,7 +210,7 @@ PropagationResult ICPSolver::doPropagationRound() Trace("nl-icp") << "ICP budget exceeded" << std::endl; return PropagationResult::NOT_CHANGED; } - d_state.d_conflict = Node(); + d_state.d_conflict.clear(); Trace("nl-icp") << "Starting propagation with " << IAWrapper{d_state.d_assignment, d_mapper} << std::endl; Trace("nl-icp") << "Current budget: " << d_budget << std::endl; @@ -267,7 +267,7 @@ std::vector ICPSolver::generateLemmas() const Node c = nm->mkNode(rel, v, value_to_node(get_lower(i), v)); if (!d_state.d_origins.isInOrigins(v, c)) { - Node premise = d_state.d_origins.getOrigins(v); + Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v)); Trace("nl-icp") << premise << " => " << c << std::endl; Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c)); if (lemma.isConst()) @@ -287,7 +287,7 @@ std::vector ICPSolver::generateLemmas() const Node c = nm->mkNode(rel, v, value_to_node(get_upper(i), v)); if (!d_state.d_origins.isInOrigins(v, c)) { - Node premise = d_state.d_origins.getOrigins(v); + Node premise = nm->mkAnd(d_state.d_origins.getOrigins(v)); Trace("nl-icp") << premise << " => " << c << std::endl; Node lemma = Rewriter::rewrite(nm->mkNode(Kind::IMPLIES, premise, c)); if (lemma.isConst()) @@ -343,7 +343,13 @@ void ICPSolver::check() Trace("nl-icp") << "Found a conflict: " << d_state.d_conflict << std::endl; - d_im.addConflict(d_state.d_conflict, InferenceId::NL_ICP_CONFLICT); + std::vector mis; + for (const auto& n : d_state.d_conflict) + { + mis.emplace_back(n.negate()); + } + d_im.addPendingArithLemma(NodeManager::currentNM()->mkOr(mis), + InferenceId::NL_ICP_CONFLICT); did_progress = true; progress = false; break; diff --git a/src/theory/arith/nl/icp/icp_solver.h b/src/theory/arith/nl/icp/icp_solver.h index ca2aef10a..32861c641 100644 --- a/src/theory/arith/nl/icp/icp_solver.h +++ b/src/theory/arith/nl/icp/icp_solver.h @@ -67,8 +67,8 @@ class ICPSolver poly::IntervalAssignment d_assignment; /** The origins for the current assignment */ ContractionOriginManager d_origins; - /** The conflict, if any way found. Initially the null node */ - Node d_conflict; + /** The conflict, if any way found. Initially empty */ + std::vector d_conflict; /** Initialized the variable bounds with a variable mapper */ ICPState(VariableMapper& vm) {} @@ -80,7 +80,7 @@ class ICPSolver d_candidates.clear(); d_assignment.clear(); d_origins = ContractionOriginManager(); - d_conflict = Node(); + d_conflict.clear(); } }; diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index 76f37213a..fdab6d7b7 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -21,6 +21,7 @@ #include "options/theory_options.h" #include "theory/arith/arith_state.h" #include "theory/arith/arith_utilities.h" +#include "theory/arith/bound_inference.h" #include "theory/arith/theory_arith.h" #include "theory/ext_theory.h" #include "theory/theory_model.h" @@ -179,16 +180,15 @@ void NonlinearExtension::getAssertions(std::vector& assertions) } Valuation v = d_containing.getValuation(); NodeManager* nm = NodeManager::currentNM(); - // get the assertions - std::map init_bounds[2]; - std::map init_bounds_lit[2]; - unsigned nassertions = 0; + + BoundInference bounds; + std::unordered_set init_assertions; + for (Theory::assertions_iterator it = d_containing.facts_begin(); it != d_containing.facts_end(); ++it) { - nassertions++; const Assertion& assertion = *it; Trace("nl-ext") << "Loaded " << assertion.d_assertion << " from theory" << std::endl; @@ -198,97 +198,23 @@ void NonlinearExtension::getAssertions(std::vector& assertions) // not relevant, skip continue; } - init_assertions.insert(lit); - // check for concrete bounds - bool pol = lit.getKind() != NOT; - Node atom_orig = lit.getKind() == NOT ? lit[0] : lit; - - std::vector atoms; - if (atom_orig.getKind() == EQUAL) - { - if (pol) - { - // t = s is ( t >= s ^ t <= s ) - for (unsigned i = 0; i < 2; i++) - { - Node atom_new = nm->mkNode(GEQ, atom_orig[i], atom_orig[1 - i]); - atom_new = Rewriter::rewrite(atom_new); - atoms.push_back(atom_new); - } - } - } - else + if (bounds.add(lit, false)) { - atoms.push_back(atom_orig); + continue; } + init_assertions.insert(lit); + } - for (const Node& atom : atoms) + for (const auto& vb : bounds.get()) + { + const Bounds& b = vb.second; + if (!b.lower_bound.isNull()) { - // non-strict bounds only - if (atom.getKind() == GEQ || (!pol && atom.getKind() == GT)) - { - Node p = atom[0]; - Assert(atom[1].isConst()); - Rational bound = atom[1].getConst(); - if (!pol) - { - if (atom[0].getType().isInteger()) - { - // ~( p >= c ) ---> ( p <= c-1 ) - bound = bound - Rational(1); - } - } - unsigned bindex = pol ? 0 : 1; - bool setBound = true; - std::map::iterator itb = init_bounds[bindex].find(p); - if (itb != init_bounds[bindex].end()) - { - if (itb->second == bound) - { - setBound = atom_orig.getKind() == EQUAL; - } - else - { - setBound = pol ? itb->second < bound : itb->second > bound; - } - if (setBound) - { - // the bound is subsumed - init_assertions.erase(init_bounds_lit[bindex][p]); - } - } - if (setBound) - { - Trace("nl-ext-init") << (pol ? "Lower" : "Upper") << " bound for " - << p << " : " << bound << std::endl; - init_bounds[bindex][p] = bound; - init_bounds_lit[bindex][p] = lit; - } - } + init_assertions.insert(b.lower_bound); } - } - // for each bound that is the same, ensure we've inferred the equality - for (std::pair& ib : init_bounds[0]) - { - Node p = ib.first; - Node lit1 = init_bounds_lit[0][p]; - if (lit1.getKind() != EQUAL) + if (!b.upper_bound.isNull()) { - std::map::iterator itb = init_bounds[1].find(p); - if (itb != init_bounds[1].end()) - { - if (ib.second == itb->second) - { - Node eq = p.eqNode(nm->mkConst(ib.second)); - eq = Rewriter::rewrite(eq); - Node lit2 = init_bounds_lit[1][p]; - Assert(lit2.getKind() != EQUAL); - // use the equality instead, thus these are redundant - init_assertions.erase(lit1); - init_assertions.erase(lit2); - init_assertions.insert(eq); - } - } + init_assertions.insert(b.upper_bound); } } @@ -301,6 +227,7 @@ void NonlinearExtension::getAssertions(std::vector& assertions) auto iait = init_assertions.find(lit); if (iait != init_assertions.end()) { + Trace("nl-ext") << "Adding " << lit << std::endl; assertions.push_back(lit); init_assertions.erase(iait); } @@ -309,10 +236,12 @@ void NonlinearExtension::getAssertions(std::vector& assertions) // function by the code above. for (const Node& a : init_assertions) { + Trace("nl-ext") << "Adding " << a << std::endl; assertions.push_back(a); } - Trace("nl-ext") << "...keep " << assertions.size() << " / " << nassertions - << " assertions." << std::endl; + Trace("nl-ext") << "...keep " << assertions.size() << " / " + << d_containing.numAssertions() << " assertions." + << std::endl; } std::vector NonlinearExtension::checkModelEval( diff --git a/src/theory/arith/nl/nonlinear_extension.h b/src/theory/arith/nl/nonlinear_extension.h index 2f4586d78..bd3042231 100644 --- a/src/theory/arith/nl/nonlinear_extension.h +++ b/src/theory/arith/nl/nonlinear_extension.h @@ -249,6 +249,7 @@ class NonlinearExtension * and for establishing when we are able to answer "SAT". */ NlModel d_model; + /** The transcendental extension object * * This is the subsolver responsible for running the procedure for diff --git a/src/theory/arith/nl/poly_conversion.cpp b/src/theory/arith/nl/poly_conversion.cpp index a76a781c4..0e4e21b76 100644 --- a/src/theory/arith/nl/poly_conversion.cpp +++ b/src/theory/arith/nl/poly_conversion.cpp @@ -785,12 +785,12 @@ poly::IntervalAssignment getBounds(VariableMapper& vm, const BoundInference& bi) for (const auto& vb : bi.get()) { poly::Variable v = vm(vb.first); - poly::Value l = vb.second.lower.isNull() + poly::Value l = vb.second.lower_value.isNull() ? poly::Value::minus_infty() - : node_to_value(vb.second.lower, vb.first); - poly::Value u = vb.second.upper.isNull() + : node_to_value(vb.second.lower_value, vb.first); + poly::Value u = vb.second.upper_value.isNull() ? poly::Value::plus_infty() - : node_to_value(vb.second.upper, vb.first); + : node_to_value(vb.second.upper_value, vb.first); poly::Interval i(l, vb.second.lower_strict, u, vb.second.upper_strict); res.set(v, i); } -- 2.30.2