From cb7c3b3575facae1ef9fbc433adfdf7260b379cf Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Mon, 13 Dec 2021 09:09:21 -0800 Subject: [PATCH] Improve nonlinear solver (#7787) This PR does two things: we remove splitting on shared values we add variable elimination for the cad-based solver, exploiting equalities present in the input. --- src/CMakeLists.txt | 2 + src/options/arith_options.toml | 8 + src/smt/set_defaults.cpp | 2 + src/theory/arith/nl/cad/cdcac.cpp | 49 +++-- src/theory/arith/nl/cad/cdcac.h | 15 ++ src/theory/arith/nl/cad/lazard_evaluation.cpp | 48 +++-- src/theory/arith/nl/cad/lazard_evaluation.h | 5 + src/theory/arith/nl/cad_solver.cpp | 53 ++++- src/theory/arith/nl/cad_solver.h | 4 + src/theory/arith/nl/equality_substitution.cpp | 183 ++++++++++++++++++ src/theory/arith/nl/equality_substitution.h | 102 ++++++++++ src/theory/arith/nl/nonlinear_extension.cpp | 77 +------- src/theory/arith/nl/strategy.cpp | 5 +- src/theory/substitutions.cpp | 25 ++- src/theory/substitutions.h | 8 +- .../arith/issue5219-conflict-rewrite.smt2 | 2 +- test/regress/regress1/nl/cos1-tc.smt2 | 2 +- 17 files changed, 468 insertions(+), 122 deletions(-) create mode 100644 src/theory/arith/nl/equality_substitution.cpp create mode 100644 src/theory/arith/nl/equality_substitution.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 07f1495fe..1d57dfeb4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -422,6 +422,8 @@ libcvc5_add_sources( theory/arith/nl/cad/proof_generator.h theory/arith/nl/cad/variable_ordering.cpp theory/arith/nl/cad/variable_ordering.h + theory/arith/nl/equality_substitution.cpp + theory/arith/nl/equality_substitution.h theory/arith/nl/ext/constraint.cpp theory/arith/nl/ext/constraint.h theory/arith/nl/ext/factoring_check.cpp diff --git a/src/options/arith_options.toml b/src/options/arith_options.toml index e5f65684b..5e6796864 100644 --- a/src/options/arith_options.toml +++ b/src/options/arith_options.toml @@ -499,6 +499,14 @@ name = "Arithmetic Theory" default = "false" help = "whether to use the cylindrical algebraic coverings solver for non-linear arithmetic" +[[option]] + name = "nlCadVarElim" + category = "regular" + long = "nl-cad-var-elim" + type = "bool" + default = "false" + help = "whether to eliminate variables using equalities before going into the cylindrical algebraic coverings solver" + [[option]] name = "nlCadPrune" category = "regular" diff --git a/src/smt/set_defaults.cpp b/src/smt/set_defaults.cpp index 9c5a5a6b3..194290399 100644 --- a/src/smt/set_defaults.cpp +++ b/src/smt/set_defaults.cpp @@ -807,6 +807,7 @@ void SetDefaults::setDefaultsPost(const LogicInfo& logic, Options& opts) const if (!opts.arith.nlCad && !opts.arith.nlCadWasSetByUser) { opts.arith.nlCad = true; + opts.arith.nlCadVarElim = true; if (!opts.arith.nlExtWasSetByUser) { opts.arith.nlExt = options::NlExtMode::LIGHT; @@ -823,6 +824,7 @@ void SetDefaults::setDefaultsPost(const LogicInfo& logic, Options& opts) const if (!opts.arith.nlCad && !opts.arith.nlCadWasSetByUser) { opts.arith.nlCad = true; + opts.arith.nlCadVarElim = true; if (!opts.arith.nlExtWasSetByUser) { opts.arith.nlExt = options::NlExtMode::LIGHT; diff --git a/src/theory/arith/nl/cad/cdcac.cpp b/src/theory/arith/nl/cad/cdcac.cpp index 2fc77be1b..18ccf7aca 100644 --- a/src/theory/arith/nl/cad/cdcac.cpp +++ b/src/theory/arith/nl/cad/cdcac.cpp @@ -105,16 +105,7 @@ std::vector CDCAC::getUnsatIntervals(std::size_t cur_variable) { std::vector res; LazardEvaluation le; - if (options().arith.nlCadLifting - == options::NlCadLiftingMode::LAZARD) - { - for (size_t vid = 0; vid < cur_variable; ++vid) - { - const auto& val = d_assignment.get(d_variableOrdering[vid]); - le.add(d_variableOrdering[vid], val); - } - le.addFreeVariable(d_variableOrdering[cur_variable]); - } + prepareRootIsolation(le, cur_variable); for (const auto& c : d_constraints.getConstraints()) { const poly::Polynomial& p = std::get<0>(c); @@ -428,11 +419,17 @@ CACInterval CDCAC::intervalFromCharacterization( m.pushDownPolys(d, d_variableOrdering[cur_variable]); // Collect -oo, all roots, oo + + LazardEvaluation le; + prepareRootIsolation(le, cur_variable); std::vector roots; roots.emplace_back(poly::Value::minus_infty()); for (const auto& p : m) { - auto tmp = isolate_real_roots(p, d_assignment); + Trace("cdcac") << "Isolating real roots of " << p << " over " + << d_assignment << std::endl; + + auto tmp = isolateRealRoots(le, p); roots.insert(roots.end(), tmp.begin(), tmp.end()); } roots.emplace_back(poly::Value::plus_infty()); @@ -464,6 +461,8 @@ CACInterval CDCAC::intervalFromCharacterization( d_assignment.set(d_variableOrdering[cur_variable], lower); for (const auto& p : m) { + Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment + << std::endl; if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ)) { l.add(p, true); @@ -477,6 +476,8 @@ CACInterval CDCAC::intervalFromCharacterization( d_assignment.set(d_variableOrdering[cur_variable], upper); for (const auto& p : m) { + Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment + << std::endl; if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ)) { u.add(p, true); @@ -570,8 +571,10 @@ std::vector CDCAC::getUnsatCoverImpl(std::size_t curVariable, d_assignment.unset(d_variableOrdering[curVariable]); + Trace("cdcac") << "Building interval..." << std::endl; auto newInterval = intervalFromCharacterization(characterization, curVariable, sample); + Trace("cdcac") << "New interval: " << newInterval.d_interval << std::endl; newInterval.d_origins = collectConstraints(cov); intervals.emplace_back(newInterval); if (isProofEnabled()) @@ -730,6 +733,30 @@ void CDCAC::pruneRedundantIntervals(std::vector& intervals) } } +void CDCAC::prepareRootIsolation(LazardEvaluation& le, + size_t cur_variable) const +{ + if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD) + { + for (size_t vid = 0; vid < cur_variable; ++vid) + { + const auto& val = d_assignment.get(d_variableOrdering[vid]); + le.add(d_variableOrdering[vid], val); + } + le.addFreeVariable(d_variableOrdering[cur_variable]); + } +} + +std::vector CDCAC::isolateRealRoots( + LazardEvaluation& le, const poly::Polynomial& p) const +{ + if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD) + { + return le.isolateRealRoots(p); + } + return poly::isolate_real_roots(p, d_assignment); +} + } // namespace cad } // namespace nl } // namespace arith diff --git a/src/theory/arith/nl/cad/cdcac.h b/src/theory/arith/nl/cad/cdcac.h index 04b5cab24..8317c0813 100644 --- a/src/theory/arith/nl/cad/cdcac.h +++ b/src/theory/arith/nl/cad/cdcac.h @@ -29,6 +29,7 @@ #include "smt/env_obj.h" #include "theory/arith/nl/cad/cdcac_utils.h" #include "theory/arith/nl/cad/constraints.h" +#include "theory/arith/nl/cad/lazard_evaluation.h" #include "theory/arith/nl/cad/proof_generator.h" #include "theory/arith/nl/cad/variable_ordering.h" @@ -195,6 +196,20 @@ class CDCAC : protected EnvObj */ void pruneRedundantIntervals(std::vector& intervals); + /** + * Prepare the lazard evaluation object with the current assignment, if the + * lazard lifting is enabled. Otherwise, this function does nothing. + */ + void prepareRootIsolation(LazardEvaluation& le, size_t cur_variable) const; + + /** + * Isolates the real roots of the polynomial `p`. If the lazard lifting is + * enabled, this function uses `le.isolateRealRoots()`, otherwise uses the + * regular `poly::isolate_real_roots()`. + */ + std::vector isolateRealRoots(LazardEvaluation& le, + const poly::Polynomial& p) const; + /** * The current assignment. When the method terminates with SAT, it contains a * model for the input constraints. diff --git a/src/theory/arith/nl/cad/lazard_evaluation.cpp b/src/theory/arith/nl/cad/lazard_evaluation.cpp index aec0d46e3..032565d3d 100644 --- a/src/theory/arith/nl/cad/lazard_evaluation.cpp +++ b/src/theory/arith/nl/cad/lazard_evaluation.cpp @@ -821,22 +821,11 @@ std::vector LazardEvaluation::reducePolynomial( return {p}; } -/** - * Compute the infeasible regions of the given polynomial according to a sign - * condition. We first reduce the polynomial and isolate the real roots of every - * resulting polynomial. We store all roots (except for -infty, +infty and none) - * in a set. Then, we transform the set of roots into a list of infeasible - * regions by generating intervals between -infty and the first root, in between - * every two consecutive roots and between the last root and +infty. While doing - * this, we only keep those intervals that are actually infeasible for the - * original polynomial q over the partial assignment. Finally, we go over the - * intervals and aggregate consecutive intervals that connect. - */ -std::vector LazardEvaluation::infeasibleRegions( - const poly::Polynomial& q, poly::SignCondition sc) const +std::vector LazardEvaluation::isolateRealRoots( + const poly::Polynomial& q) const { poly::Assignment a; - std::set roots; + std::vector roots; // reduce q to a set of reduced polynomials p for (const auto& p : reducePolynomial(q)) { @@ -849,9 +838,28 @@ std::vector LazardEvaluation::infeasibleRegions( if (poly::is_minus_infinity(r)) continue; if (poly::is_none(r)) continue; if (poly::is_plus_infinity(r)) continue; - roots.insert(r); + roots.emplace_back(r); } } + std::sort(roots.begin(), roots.end()); + return roots; +} + +/** + * Compute the infeasible regions of the given polynomial according to a sign + * condition. We first reduce the polynomial and isolate the real roots of every + * resulting polynomial. We store all roots (except for -infty, +infty and none) + * in a set. Then, we transform the set of roots into a list of infeasible + * regions by generating intervals between -infty and the first root, in between + * every two consecutive roots and between the last root and +infty. While doing + * this, we only keep those intervals that are actually infeasible for the + * original polynomial q over the partial assignment. Finally, we go over the + * intervals and aggregate consecutive intervals that connect. + */ +std::vector LazardEvaluation::infeasibleRegions( + const poly::Polynomial& q, poly::SignCondition sc) const +{ + std::vector roots = isolateRealRoots(q); // generate all intervals // (-infty,root_0), [root_0], (root_0,root_1), ..., [root_m], (root_m,+infty) @@ -962,6 +970,16 @@ std::vector LazardEvaluation::reducePolynomial( { return {p}; } + +std::vector LazardEvaluation::isolateRealRoots( + const poly::Polynomial& q) const +{ + WarningOnce() + << "CAD::LazardEvaluation is disabled because CoCoA is not available. " + "Falling back to regular real root isolation." + << std::endl; + return poly::isolate_real_roots(q, d_state->d_assignment); +} std::vector LazardEvaluation::infeasibleRegions( const poly::Polynomial& q, poly::SignCondition sc) const { diff --git a/src/theory/arith/nl/cad/lazard_evaluation.h b/src/theory/arith/nl/cad/lazard_evaluation.h index 3bb971c4c..2afccb462 100644 --- a/src/theory/arith/nl/cad/lazard_evaluation.h +++ b/src/theory/arith/nl/cad/lazard_evaluation.h @@ -93,6 +93,11 @@ class LazardEvaluation std::vector reducePolynomial( const poly::Polynomial& q) const; + /** + * Isolates the real roots of the given polynomials. + */ + std::vector isolateRealRoots(const poly::Polynomial& q) const; + /** * Compute the infeasible regions of q under the given sign condition. * Uses reducePolynomial and then performs real root isolation on the diff --git a/src/theory/arith/nl/cad_solver.cpp b/src/theory/arith/nl/cad_solver.cpp index 721308a3d..f4582ac20 100644 --- a/src/theory/arith/nl/cad_solver.cpp +++ b/src/theory/arith/nl/cad_solver.cpp @@ -16,12 +16,14 @@ #include "theory/arith/nl/cad_solver.h" #include "expr/skolem_manager.h" +#include "options/arith_options.h" #include "smt/env.h" #include "theory/arith/inference_manager.h" #include "theory/arith/nl/cad/cdcac.h" #include "theory/arith/nl/nl_model.h" #include "theory/arith/nl/poly_conversion.h" #include "theory/inference_id.h" +#include "theory/theory.h" namespace cvc5 { namespace theory { @@ -36,7 +38,8 @@ CadSolver::CadSolver(Env& env, InferenceManager& im, NlModel& model) #endif d_foundSatisfiability(false), d_im(im), - d_model(model) + d_model(model), + d_eqsubs(env) { NodeManager* nm = NodeManager::currentNM(); SkolemManager* sm = nm->getSkolemManager(); @@ -65,11 +68,41 @@ void CadSolver::initLastCall(const std::vector& assertions) Trace("nl-cad") << " " << a << std::endl; } } - // store or process assertions - d_CAC.reset(); - for (const Node& a : assertions) + if (options().arith.nlCadVarElim) { - d_CAC.getConstraints().addConstraint(a); + d_eqsubs.reset(); + std::vector processed = d_eqsubs.eliminateEqualities(assertions); + if (d_eqsubs.hasConflict()) + { + Node lem = NodeManager::currentNM()->mkAnd(d_eqsubs.getConflict()).negate(); + d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, nullptr); + Trace("nl-cad") << "Found conflict: " << lem << std::endl; + return; + } + if (Trace.isOn("nl-cad")) + { + Trace("nl-cad") << "After simplifications" << std::endl; + Trace("nl-cad") << "* Assertions: " << std::endl; + for (const Node& a : processed) + { + Trace("nl-cad") << " " << a << std::endl; + } + } + d_CAC.reset(); + for (const Node& a : processed) + { + Assert(!a.isConst()); + d_CAC.getConstraints().addConstraint(a); + } + } + else + { + d_CAC.reset(); + for (const Node& a : assertions) + { + Assert(!a.isConst()); + d_CAC.getConstraints().addConstraint(a); + } } d_CAC.computeVariableOrdering(); d_CAC.retrieveInitialAssignment(d_model, d_ranVariable); @@ -84,6 +117,7 @@ void CadSolver::checkFull() { #ifdef CVC5_POLY_IMP if (d_CAC.getConstraints().getConstraints().empty()) { + d_foundSatisfiability = true; Trace("nl-cad") << "No constraints. Return." << std::endl; return; } @@ -101,6 +135,8 @@ void CadSolver::checkFull() Trace("nl-cad") << "Collected MIS: " << mis << std::endl; Assert(!mis.empty()) << "Infeasible subset can not be empty"; Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl; + d_eqsubs.postprocessConflict(mis); + Trace("nl-cad") << "After postprocessing: " << mis << std::endl; Node lem = NodeManager::currentNM()->mkAnd(mis).negate(); ProofGenerator* proof = d_CAC.closeProof(mis); d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, proof); @@ -170,10 +206,15 @@ bool CadSolver::constructModelIfAvailable(std::vector& assertions) return false; } bool foundNonVariable = false; + for (const auto& sub: d_eqsubs.getSubstitutions()) + { + d_model.addSubstitution(sub.first, sub.second); + Trace("nl-cad") << "-> " << sub.first << " = " << sub.second << std::endl; + } for (const auto& v : d_CAC.getVariableOrdering()) { Node variable = d_CAC.getConstraints().varMapper()(v); - if (!variable.isVar()) + if (!Theory::isLeafOf(variable, TheoryId::THEORY_ARITH)) { Trace("nl-cad") << "Not a variable: " << variable << std::endl; foundNonVariable = true; diff --git a/src/theory/arith/nl/cad_solver.h b/src/theory/arith/nl/cad_solver.h index bedffcaa9..73d09378b 100644 --- a/src/theory/arith/nl/cad_solver.h +++ b/src/theory/arith/nl/cad_solver.h @@ -23,6 +23,7 @@ #include "smt/env_obj.h" #include "theory/arith/nl/cad/cdcac.h" #include "theory/arith/nl/cad/proof_checker.h" +#include "theory/arith/nl/equality_substitution.h" namespace cvc5 { @@ -104,6 +105,9 @@ class CadSolver: protected EnvObj InferenceManager& d_im; /** Reference to the non-linear model object */ NlModel& d_model; + /** Utility to eliminate variables from simple equalities before going into + * the actual coverings solver */ + EqualitySubstitution d_eqsubs; }; /* class CadSolver */ } // namespace nl diff --git a/src/theory/arith/nl/equality_substitution.cpp b/src/theory/arith/nl/equality_substitution.cpp new file mode 100644 index 000000000..9b3a79cd4 --- /dev/null +++ b/src/theory/arith/nl/equality_substitution.cpp @@ -0,0 +1,183 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer, Andrew Reynolds, Andres Noetzli + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Implementation of new non-linear solver. + */ + +#include "theory/arith/nl/equality_substitution.h" + +#include "smt/env.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace nl { + +EqualitySubstitution::EqualitySubstitution(Env& env) + : EnvObj(env), d_substitutions(std::make_unique()) +{ +} +void EqualitySubstitution::reset() +{ + d_substitutions = std::make_unique(); + d_conflict.clear(); + d_conflictMap.clear(); + d_trackOrigin.clear(); +} + +std::vector EqualitySubstitution::eliminateEqualities( + const std::vector& assertions) +{ + Trace("nl-eqs") << "Input:" << std::endl; + for (const auto& a : assertions) + { + Trace("nl-eqs") << "\t" << a << std::endl; + } + std::set tracker; + std::vector asserts = assertions; + std::vector next; + + size_t last_size = 0; + while (asserts.size() != last_size) + { + last_size = asserts.size(); + // collect all eliminations from original into d_substitutions + for (const auto& orig : asserts) + { + if (orig.getKind() != Kind::EQUAL) continue; + tracker.clear(); + d_substitutions->invalidateCache(); + Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker); + Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o + << std::endl; + if (o.getKind() != Kind::EQUAL) continue; + Assert(o.getNumChildren() == 2); + for (size_t i = 0; i < 2; ++i) + { + const auto& l = o[i]; + const auto& r = o[1 - i]; + if (l.isConst()) continue; + if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue; + if (d_substitutions->hasSubstitution(l)) continue; + if (expr::hasSubterm(r, l, true)) continue; + Trace("nl-eqs") << "Found substitution " << l << " -> " << r + << std::endl + << " from " << o << " / " << orig << std::endl; + d_substitutions->addSubstitution(l, r); + d_trackOrigin.emplace(l, o); + if (o != orig) + { + addToConflictMap(o, orig, tracker); + } + break; + } + } + + // simplify with subs from original into next + next.clear(); + for (const auto& a : asserts) + { + tracker.clear(); + d_substitutions->invalidateCache(); + Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker); + if (simp.isConst()) + { + if (simp.getConst()) + { + continue; + } + Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; + for (TNode t : tracker) + { + Trace("nl-eqs") << "Tracker has " << t << std::endl; + auto toit = d_trackOrigin.find(t); + Assert(toit != d_trackOrigin.end()); + d_conflict.emplace_back(toit->second); + } + d_conflict.emplace_back(a); + postprocessConflict(d_conflict); + Trace("nl-eqs") << "Direct conflict: " << d_conflict << std::endl; + Trace("nl-eqs") << std::endl + << d_conflict.size() << " vs " + << std::distance(d_substitutions->begin(), + d_substitutions->end()) + << std::endl + << std::endl; + return {}; + } + if (simp != a) + { + Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl; + addToConflictMap(simp, a, tracker); + } + next.emplace_back(simp); + } + asserts = std::move(next); + } + d_conflict.clear(); + return asserts; +} +void EqualitySubstitution::postprocessConflict( + std::vector& conflict) const +{ + Trace("nl-eqs") << "Postprocessing " << conflict << std::endl; + std::set result; + for (const auto& c : conflict) + { + auto it = d_conflictMap.find(c); + if (it == d_conflictMap.end()) + { + result.insert(c); + } + else + { + Trace("nl-eqs") << "Origin of " << c << ": " << it->second << std::endl; + result.insert(it->second.begin(), it->second.end()); + } + } + conflict.clear(); + conflict.insert(conflict.end(), result.begin(), result.end()); + Trace("nl-eqs") << "-> " << conflict << std::endl; +} +void EqualitySubstitution::insertOrigins(std::set& dest, + const Node& n) const +{ + auto it = d_conflictMap.find(n); + if (it == d_conflictMap.end()) + { + dest.insert(n); + } + else + { + dest.insert(it->second.begin(), it->second.end()); + } +} +void EqualitySubstitution::addToConflictMap(const Node& n, + const Node& orig, + const std::set& tracker) +{ + std::set origins; + insertOrigins(origins, orig); + for (const auto& t : tracker) + { + auto tit = d_trackOrigin.find(t); + Assert(tit != d_trackOrigin.end()); + insertOrigins(origins, tit->second); + } + Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl; + d_conflictMap.emplace(n, std::vector(origins.begin(), origins.end())); +} + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/arith/nl/equality_substitution.h b/src/theory/arith/nl/equality_substitution.h new file mode 100644 index 000000000..b095af8df --- /dev/null +++ b/src/theory/arith/nl/equality_substitution.h @@ -0,0 +1,102 @@ +/****************************************************************************** + * Top contributors (to current version): + * Gereon Kremer + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * CAD-based solver based on https://arxiv.org/pdf/2003.05633.pdf. + */ + +#ifndef CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H +#define CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H + +#include + +#include "context/context.h" +#include "expr/node.h" +#include "expr/node_algorithm.h" +#include "smt/env_obj.h" +#include "theory/substitutions.h" +#include "theory/theory.h" + +namespace cvc5 { +namespace theory { +namespace arith { +namespace nl { + +/** + * This class is a general utility to eliminate variables from a set of + * assertions. + */ +class EqualitySubstitution : protected EnvObj +{ + public: + EqualitySubstitution(Env& env); + /** Reset this object */ + void reset(); + + /** + * Eliminate variables using equalities from the set of assertions. + * Returns a set of assertions where some variables may have been eliminated. + * Substitutions for the eliminated variables can be obtained from + * getSubstitutions(). + */ + std::vector eliminateEqualities(const std::vector& assertions); + /** + * Can be called after eliminateEqualities(). Returns the substitutions that + * were found and eliminated. + */ + const SubstitutionMap& getSubstitutions() const { return *d_substitutions; } + /** + * Can be called after eliminateEqualities(). Checks whether a direct conflict + * was found, that is an assertion simplified to false during + * eliminateEqualities(). + */ + bool hasConflict() const { return !d_conflict.empty(); } + /** + * Return the conflict found in eliminateEqualities() as a set of assertions + * that is a subset of the input assertions provided to eliminateEqualities(). + */ + const std::vector& getConflict() const { return d_conflict; } + /** + * Postprocess a conflict found in the result of eliminateEqualities. + * Replaces assertions within the conflict by their origins, i.e. the input + * assertions and the assertions that gave rise to the substitutions being + * used. + */ + void postprocessConflict(std::vector& conflict) const; + + private: + /** Utility method for addToConflictMap. Checks for n in d_conflictMap */ + void insertOrigins(std::set& dest, const Node& n) const; + /** Add n -> { orig, *tracker } to the conflict map. The tracked nodes are + * first resolved using d_trackOrigin, and everything is run through + * insertOrigins to make sure that all origins are input assertions. */ + void addToConflictMap(const Node& n, + const Node& orig, + const std::set& tracker); + + // The SubstitutionMap + std::unique_ptr d_substitutions; + // conflicting assertions, if a conflict was found + std::vector d_conflict; + // Maps a simplified assertion to the original assertion + set of original + // assertions used for substitutions + std::map> d_conflictMap; + // Maps substituted terms (what will end up in the tracker) to the equality + // from which the substitution was derived. + std::map d_trackOrigin; +}; + +} // namespace nl +} // namespace arith +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H */ diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index 77bb164a9..3f60f8596 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -353,45 +353,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set& termS } // compute whether shared terms have correct values - unsigned num_shared_wrong_value = 0; - std::vector shared_term_value_splits; - // must ensure that shared terms are equal to their concrete value - Trace("nl-ext-mv") << "Shared terms : " << std::endl; - for (context::CDList::const_iterator its = - d_containing.shared_terms_begin(); - its != d_containing.shared_terms_end(); - ++its) - { - TNode shared_term = *its; - // compute its value in the model, and its evaluation in the model - Node stv0 = d_model.computeConcreteModelValue(shared_term); - Node stv1 = d_model.computeAbstractModelValue(shared_term); - d_model.printModelValue("nl-ext-mv", shared_term); - if (stv0 != stv1) - { - num_shared_wrong_value++; - Trace("nl-ext-mv") << "Bad shared term value : " << shared_term - << std::endl; - if (shared_term != stv0) - { - // split on the value, this is non-terminating in general, TODO : - // improve this - Node eq = shared_term.eqNode(stv0); - shared_term_value_splits.push_back(eq); - } - else - { - // this can happen for transcendental functions - // the problem is that we cannot evaluate transcendental functions - // (they don't have a rewriter that returns constants) - // thus, the actual value in their model can be themselves, hence we - // have no reference point to rule out the current model. In this - // case, we may set incomplete below. - } - } - } - Trace("nl-ext-debug") << " " << num_shared_wrong_value - << " shared terms with wrong model value." << std::endl; bool needsRecheck; do { @@ -402,9 +363,9 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set& termS int complete_status = 1; // We require a check either if an assertion is false or a shared term has // a wrong value - if (!false_asserts.empty() || num_shared_wrong_value > 0) + if (!false_asserts.empty()) { - complete_status = num_shared_wrong_value > 0 ? -1 : 0; + complete_status = 0; runStrategy(Theory::Effort::EFFORT_FULL, assertions, false_asserts, xts); if (d_im.hasSentLemma() || d_im.hasPendingLemma()) { @@ -446,40 +407,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set& termS << std::endl; return Result::Sat::UNSAT; } - // resort to splitting on shared terms with their model value - // if we did not add any lemmas - if (num_shared_wrong_value > 0) - { - complete_status = -1; - if (!shared_term_value_splits.empty()) - { - for (const Node& eq : shared_term_value_splits) - { - Node req = rewrite(eq); - Node literal = d_containing.getValuation().ensureLiteral(req); - d_containing.getOutputChannel().requirePhase(literal, true); - Trace("nl-ext-debug") << "Split on : " << literal << std::endl; - Node split = literal.orNode(literal.negate()); - d_im.addPendingLemma(split, - InferenceId::ARITH_NL_SHARED_TERM_VALUE_SPLIT, - nullptr, - true); - } - if (d_im.hasWaitingLemma()) - { - d_im.flushWaitingLemmas(); - Trace("nl-ext") << "...added " << d_im.numPendingLemmas() - << " shared term value split lemmas." << std::endl; - return Result::Sat::UNSAT; - } - } - else - { - // this can happen if we are trying to do theory combination with - // trancendental functions - // since their model value cannot even be computed exactly - } - } // we are incomplete if (options().arith.nlExt == options::NlExtMode::FULL diff --git a/src/theory/arith/nl/strategy.cpp b/src/theory/arith/nl/strategy.cpp index b33e45129..a14841f67 100644 --- a/src/theory/arith/nl/strategy.cpp +++ b/src/theory/arith/nl/strategy.cpp @@ -172,10 +172,7 @@ void Strategy::initializeStrategy(const Options& options) one << InferStep::POW2_FULL << InferStep::BREAK; if (options.arith.nlCad) { - one << InferStep::CAD_INIT; - } - if (options.arith.nlCad) - { + one << InferStep::CAD_INIT << InferStep::BREAK; one << InferStep::CAD_FULL << InferStep::BREAK; } diff --git a/src/theory/substitutions.cpp b/src/theory/substitutions.cpp index f91094481..e49563046 100644 --- a/src/theory/substitutions.cpp +++ b/src/theory/substitutions.cpp @@ -39,7 +39,7 @@ struct substitution_stack_element { } };/* struct substitution_stack_element */ -Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) { +Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set* tracker) { Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << ")" << endl; @@ -70,10 +70,17 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) { if (find2 != d_substitutions.end()) { Node rhs = (*find2).second; Assert(rhs != current); - internalSubstitute(rhs, cache); - d_substitutions[current] = cache[rhs]; + internalSubstitute(rhs, cache, tracker); + if (tracker == nullptr) + { + d_substitutions[current] = cache[rhs]; + } cache[current] = cache[rhs]; toVisit.pop_back(); + if (tracker != nullptr) + { + tracker->insert(current); + } continue; } @@ -101,10 +108,14 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) { if (find2 != d_substitutions.end()) { Node rhs = (*find2).second; Assert(rhs != result); - internalSubstitute(rhs, cache); + internalSubstitute(rhs, cache, tracker); d_substitutions[result] = cache[rhs]; cache[result] = cache[rhs]; result = cache[rhs]; + if (tracker != nullptr) + { + tracker->insert(result); + } } } } @@ -184,8 +195,8 @@ void SubstitutionMap::addSubstitutions(SubstitutionMap& subMap, bool invalidateC } } -Node SubstitutionMap::apply(TNode t, Rewriter* r) -{ +Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set* tracker) { + Debug("substitution") << "SubstitutionMap::apply(" << t << ")" << endl; // Setup the cache @@ -196,7 +207,7 @@ Node SubstitutionMap::apply(TNode t, Rewriter* r) } // Perform the substitution - Node result = internalSubstitute(t, d_substitutionCache); + Node result = internalSubstitute(t, d_substitutionCache, tracker); Debug("substitution") << "SubstitutionMap::apply(" << t << ") => " << result << endl; if (r != nullptr) diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index 7a3afcb11..2154c7fd5 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -65,7 +65,7 @@ class SubstitutionMap bool d_cacheInvalidated; /** Internal method that performs substitution */ - Node internalSubstitute(TNode t, NodeCache& cache); + Node internalSubstitute(TNode t, NodeCache& cache, std::set* tracker); /** Helper class to invalidate cache on user pop */ class CacheInvalidator : public context::ContextNotifyObj @@ -130,7 +130,7 @@ class SubstitutionMap * Apply the substitutions to the node, optionally rewrite if a non-null * Rewriter pointer is passed. */ - Node apply(TNode t, Rewriter* r = nullptr); + Node apply(TNode t, Rewriter* r = nullptr, std::set* tracker = nullptr); /** * Apply the substitutions to the node. @@ -155,6 +155,10 @@ class SubstitutionMap */ void print(std::ostream& out) const; + void invalidateCache() { + d_cacheInvalidated = true; + } + }; /* class SubstitutionMap */ inline std::ostream& operator << (std::ostream& out, const SubstitutionMap& subst) { diff --git a/test/regress/regress0/arith/issue5219-conflict-rewrite.smt2 b/test/regress/regress0/arith/issue5219-conflict-rewrite.smt2 index ccb50c55d..b49287c30 100644 --- a/test/regress/regress0/arith/issue5219-conflict-rewrite.smt2 +++ b/test/regress/regress0/arith/issue5219-conflict-rewrite.smt2 @@ -1,6 +1,6 @@ ; REQUIRES: poly ; COMMAND-LINE: --theoryof-mode=term --nl-icp -; EXPECT: unknown +; EXPECT: sat (set-logic QF_NRA) (set-option :check-proofs true) (declare-fun x () Real) diff --git a/test/regress/regress1/nl/cos1-tc.smt2 b/test/regress/regress1/nl/cos1-tc.smt2 index bedc0209b..ba49f23fe 100644 --- a/test/regress/regress1/nl/cos1-tc.smt2 +++ b/test/regress/regress1/nl/cos1-tc.smt2 @@ -1,5 +1,5 @@ ; COMMAND-LINE: --nl-ext=full --no-nl-ext-tf-tplanes --no-nl-ext-inc-prec -; EXPECT: unknown +; EXPECT: sat (set-logic UFNRAT) (declare-fun f (Real) Real) -- 2.30.2