Improve nonlinear solver (#7787)
authorGereon Kremer <gkremer@stanford.edu>
Mon, 13 Dec 2021 17:09:21 +0000 (09:09 -0800)
committerGitHub <noreply@github.com>
Mon, 13 Dec 2021 17:09:21 +0000 (17:09 +0000)
This PR does two things:

we remove splitting on shared values
we add variable elimination for the cad-based solver, exploiting equalities present in the input.

17 files changed:
src/CMakeLists.txt
src/options/arith_options.toml
src/smt/set_defaults.cpp
src/theory/arith/nl/cad/cdcac.cpp
src/theory/arith/nl/cad/cdcac.h
src/theory/arith/nl/cad/lazard_evaluation.cpp
src/theory/arith/nl/cad/lazard_evaluation.h
src/theory/arith/nl/cad_solver.cpp
src/theory/arith/nl/cad_solver.h
src/theory/arith/nl/equality_substitution.cpp [new file with mode: 0644]
src/theory/arith/nl/equality_substitution.h [new file with mode: 0644]
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/strategy.cpp
src/theory/substitutions.cpp
src/theory/substitutions.h
test/regress/regress0/arith/issue5219-conflict-rewrite.smt2
test/regress/regress1/nl/cos1-tc.smt2

index 07f1495fe1b711cb111cd92b6e6c143e8cc54214..1d57dfeb48dc9c0bb4f45704cc4cb03e43e7bf0e 100644 (file)
@@ -422,6 +422,8 @@ libcvc5_add_sources(
   theory/arith/nl/cad/proof_generator.h
   theory/arith/nl/cad/variable_ordering.cpp
   theory/arith/nl/cad/variable_ordering.h
+  theory/arith/nl/equality_substitution.cpp
+  theory/arith/nl/equality_substitution.h
   theory/arith/nl/ext/constraint.cpp
   theory/arith/nl/ext/constraint.h
   theory/arith/nl/ext/factoring_check.cpp
index e5f65684b8a2c51487228ee973e58b8a5a947ecd..5e6796864d5d9301e739188a224cbc5c49e149e2 100644 (file)
@@ -499,6 +499,14 @@ name   = "Arithmetic Theory"
   default    = "false"
   help       = "whether to use the cylindrical algebraic coverings solver for non-linear arithmetic"
 
+[[option]]
+  name       = "nlCadVarElim"
+  category   = "regular"
+  long       = "nl-cad-var-elim"
+  type       = "bool"
+  default    = "false"
+  help       = "whether to eliminate variables using equalities before going into the cylindrical algebraic coverings solver"
+
 [[option]]
   name       = "nlCadPrune"
   category   = "regular"
index 9c5a5a6b3474e9d11b26469ec38d3a90c0d7b17a..1942903994908dab72ddb145f3eb4241d08d564e 100644 (file)
@@ -807,6 +807,7 @@ void SetDefaults::setDefaultsPost(const LogicInfo& logic, Options& opts) const
     if (!opts.arith.nlCad && !opts.arith.nlCadWasSetByUser)
     {
       opts.arith.nlCad = true;
+      opts.arith.nlCadVarElim = true;
       if (!opts.arith.nlExtWasSetByUser)
       {
         opts.arith.nlExt = options::NlExtMode::LIGHT;
@@ -823,6 +824,7 @@ void SetDefaults::setDefaultsPost(const LogicInfo& logic, Options& opts) const
     if (!opts.arith.nlCad && !opts.arith.nlCadWasSetByUser)
     {
       opts.arith.nlCad = true;
+      opts.arith.nlCadVarElim = true;
       if (!opts.arith.nlExtWasSetByUser)
       {
         opts.arith.nlExt = options::NlExtMode::LIGHT;
index 2fc77be1b655f3c3738846014b1696ffb28a2247..18ccf7aca71e3274f1552e8510848c527d359e92 100644 (file)
@@ -105,16 +105,7 @@ std::vector<CACInterval> CDCAC::getUnsatIntervals(std::size_t cur_variable)
 {
   std::vector<CACInterval> res;
   LazardEvaluation le;
-  if (options().arith.nlCadLifting
-      == options::NlCadLiftingMode::LAZARD)
-  {
-    for (size_t vid = 0; vid < cur_variable; ++vid)
-    {
-      const auto& val = d_assignment.get(d_variableOrdering[vid]);
-      le.add(d_variableOrdering[vid], val);
-    }
-    le.addFreeVariable(d_variableOrdering[cur_variable]);
-  }
+  prepareRootIsolation(le, cur_variable);
   for (const auto& c : d_constraints.getConstraints())
   {
     const poly::Polynomial& p = std::get<0>(c);
@@ -428,11 +419,17 @@ CACInterval CDCAC::intervalFromCharacterization(
   m.pushDownPolys(d, d_variableOrdering[cur_variable]);
 
   // Collect -oo, all roots, oo
+
+  LazardEvaluation le;
+  prepareRootIsolation(le, cur_variable);
   std::vector<poly::Value> roots;
   roots.emplace_back(poly::Value::minus_infty());
   for (const auto& p : m)
   {
-    auto tmp = isolate_real_roots(p, d_assignment);
+    Trace("cdcac") << "Isolating real roots of " << p << " over "
+                   << d_assignment << std::endl;
+
+    auto tmp = isolateRealRoots(le, p);
     roots.insert(roots.end(), tmp.begin(), tmp.end());
   }
   roots.emplace_back(poly::Value::plus_infty());
@@ -464,6 +461,8 @@ CACInterval CDCAC::intervalFromCharacterization(
     d_assignment.set(d_variableOrdering[cur_variable], lower);
     for (const auto& p : m)
     {
+      Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment
+                     << std::endl;
       if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ))
       {
         l.add(p, true);
@@ -477,6 +476,8 @@ CACInterval CDCAC::intervalFromCharacterization(
     d_assignment.set(d_variableOrdering[cur_variable], upper);
     for (const auto& p : m)
     {
+      Trace("cdcac") << "Evaluating " << p << " = 0 over " << d_assignment
+                     << std::endl;
       if (evaluate_constraint(p, d_assignment, poly::SignCondition::EQ))
       {
         u.add(p, true);
@@ -570,8 +571,10 @@ std::vector<CACInterval> CDCAC::getUnsatCoverImpl(std::size_t curVariable,
 
     d_assignment.unset(d_variableOrdering[curVariable]);
 
+    Trace("cdcac") << "Building interval..." << std::endl;
     auto newInterval =
         intervalFromCharacterization(characterization, curVariable, sample);
+    Trace("cdcac") << "New interval: " << newInterval.d_interval << std::endl;
     newInterval.d_origins = collectConstraints(cov);
     intervals.emplace_back(newInterval);
     if (isProofEnabled())
@@ -730,6 +733,30 @@ void CDCAC::pruneRedundantIntervals(std::vector<CACInterval>& intervals)
   }
 }
 
+void CDCAC::prepareRootIsolation(LazardEvaluation& le,
+                                 size_t cur_variable) const
+{
+  if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD)
+  {
+    for (size_t vid = 0; vid < cur_variable; ++vid)
+    {
+      const auto& val = d_assignment.get(d_variableOrdering[vid]);
+      le.add(d_variableOrdering[vid], val);
+    }
+    le.addFreeVariable(d_variableOrdering[cur_variable]);
+  }
+}
+
+std::vector<poly::Value> CDCAC::isolateRealRoots(
+    LazardEvaluation& le, const poly::Polynomial& p) const
+{
+  if (options().arith.nlCadLifting == options::NlCadLiftingMode::LAZARD)
+  {
+    return le.isolateRealRoots(p);
+  }
+  return poly::isolate_real_roots(p, d_assignment);
+}
+
 }  // namespace cad
 }  // namespace nl
 }  // namespace arith
index 04b5cab24422aaf16a1199db9c86d56da05ffb28..8317c08137b24b5c68d4f243f19c08a542fff638 100644 (file)
@@ -29,6 +29,7 @@
 #include "smt/env_obj.h"
 #include "theory/arith/nl/cad/cdcac_utils.h"
 #include "theory/arith/nl/cad/constraints.h"
+#include "theory/arith/nl/cad/lazard_evaluation.h"
 #include "theory/arith/nl/cad/proof_generator.h"
 #include "theory/arith/nl/cad/variable_ordering.h"
 
@@ -195,6 +196,20 @@ class CDCAC : protected EnvObj
    */
   void pruneRedundantIntervals(std::vector<CACInterval>& intervals);
 
+  /**
+   * Prepare the lazard evaluation object with the current assignment, if the
+   * lazard lifting is enabled. Otherwise, this function does nothing.
+   */
+  void prepareRootIsolation(LazardEvaluation& le, size_t cur_variable) const;
+
+  /**
+   * Isolates the real roots of the polynomial `p`. If the lazard lifting is
+   * enabled, this function uses `le.isolateRealRoots()`, otherwise uses the
+   * regular `poly::isolate_real_roots()`.
+   */
+  std::vector<poly::Value> isolateRealRoots(LazardEvaluation& le,
+                                            const poly::Polynomial& p) const;
+
   /**
    * The current assignment. When the method terminates with SAT, it contains a
    * model for the input constraints.
index aec0d46e3ef781e432d39304b123a429828d0ebe..032565d3d8c975dd121954d1e95e01e1aa2e8c80 100644 (file)
@@ -821,22 +821,11 @@ std::vector<poly::Polynomial> LazardEvaluation::reducePolynomial(
   return {p};
 }
 
-/**
- * Compute the infeasible regions of the given polynomial according to a sign
- * condition. We first reduce the polynomial and isolate the real roots of every
- * resulting polynomial. We store all roots (except for -infty, +infty and none)
- * in a set. Then, we transform the set of roots into a list of infeasible
- * regions by generating intervals between -infty and the first root, in between
- * every two consecutive roots and between the last root and +infty. While doing
- * this, we only keep those intervals that are actually infeasible for the
- * original polynomial q over the partial assignment. Finally, we go over the
- * intervals and aggregate consecutive intervals that connect.
- */
-std::vector<poly::Interval> LazardEvaluation::infeasibleRegions(
-    const poly::Polynomial& q, poly::SignCondition sc) const
+std::vector<poly::Value> LazardEvaluation::isolateRealRoots(
+    const poly::Polynomial& q) const
 {
   poly::Assignment a;
-  std::set<poly::Value> roots;
+  std::vector<poly::Value> roots;
   // reduce q to a set of reduced polynomials p
   for (const auto& p : reducePolynomial(q))
   {
@@ -849,9 +838,28 @@ std::vector<poly::Interval> LazardEvaluation::infeasibleRegions(
       if (poly::is_minus_infinity(r)) continue;
       if (poly::is_none(r)) continue;
       if (poly::is_plus_infinity(r)) continue;
-      roots.insert(r);
+      roots.emplace_back(r);
     }
   }
+  std::sort(roots.begin(), roots.end());
+  return roots;
+}
+
+/**
+ * Compute the infeasible regions of the given polynomial according to a sign
+ * condition. We first reduce the polynomial and isolate the real roots of every
+ * resulting polynomial. We store all roots (except for -infty, +infty and none)
+ * in a set. Then, we transform the set of roots into a list of infeasible
+ * regions by generating intervals between -infty and the first root, in between
+ * every two consecutive roots and between the last root and +infty. While doing
+ * this, we only keep those intervals that are actually infeasible for the
+ * original polynomial q over the partial assignment. Finally, we go over the
+ * intervals and aggregate consecutive intervals that connect.
+ */
+std::vector<poly::Interval> LazardEvaluation::infeasibleRegions(
+    const poly::Polynomial& q, poly::SignCondition sc) const
+{
+  std::vector<poly::Value> roots = isolateRealRoots(q);
 
   // generate all intervals
   // (-infty,root_0), [root_0], (root_0,root_1), ..., [root_m], (root_m,+infty)
@@ -962,6 +970,16 @@ std::vector<poly::Polynomial> LazardEvaluation::reducePolynomial(
 {
   return {p};
 }
+
+std::vector<poly::Value> LazardEvaluation::isolateRealRoots(
+    const poly::Polynomial& q) const
+{
+  WarningOnce()
+      << "CAD::LazardEvaluation is disabled because CoCoA is not available. "
+         "Falling back to regular real root isolation."
+      << std::endl;
+  return poly::isolate_real_roots(q, d_state->d_assignment);
+}
 std::vector<poly::Interval> LazardEvaluation::infeasibleRegions(
     const poly::Polynomial& q, poly::SignCondition sc) const
 {
index 3bb971c4c57aab455d5656dc8a9071eecf0e6ed9..2afccb462a8717cd027f21ee85e5f06c0b44e1ad 100644 (file)
@@ -93,6 +93,11 @@ class LazardEvaluation
   std::vector<poly::Polynomial> reducePolynomial(
       const poly::Polynomial& q) const;
 
+  /**
+   * Isolates the real roots of the given polynomials.
+   */
+  std::vector<poly::Value> isolateRealRoots(const poly::Polynomial& q) const;
+
   /**
    * Compute the infeasible regions of q under the given sign condition.
    * Uses reducePolynomial and then performs real root isolation on the
index 721308a3d34e9556aa58e305d822255248b7a12b..f4582ac2017d3a7b6cac8e43944cddfe65e74400 100644 (file)
 #include "theory/arith/nl/cad_solver.h"
 
 #include "expr/skolem_manager.h"
+#include "options/arith_options.h"
 #include "smt/env.h"
 #include "theory/arith/inference_manager.h"
 #include "theory/arith/nl/cad/cdcac.h"
 #include "theory/arith/nl/nl_model.h"
 #include "theory/arith/nl/poly_conversion.h"
 #include "theory/inference_id.h"
+#include "theory/theory.h"
 
 namespace cvc5 {
 namespace theory {
@@ -36,7 +38,8 @@ CadSolver::CadSolver(Env& env, InferenceManager& im, NlModel& model)
 #endif
       d_foundSatisfiability(false),
       d_im(im),
-      d_model(model)
+      d_model(model),
+      d_eqsubs(env)
 {
   NodeManager* nm = NodeManager::currentNM();
   SkolemManager* sm = nm->getSkolemManager();
@@ -65,11 +68,41 @@ void CadSolver::initLastCall(const std::vector<Node>& assertions)
       Trace("nl-cad") << "  " << a << std::endl;
     }
   }
-  // store or process assertions
-  d_CAC.reset();
-  for (const Node& a : assertions)
+  if (options().arith.nlCadVarElim)
   {
-    d_CAC.getConstraints().addConstraint(a);
+    d_eqsubs.reset();
+    std::vector<Node> processed = d_eqsubs.eliminateEqualities(assertions);
+    if (d_eqsubs.hasConflict())
+    {
+        Node lem = NodeManager::currentNM()->mkAnd(d_eqsubs.getConflict()).negate();
+        d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, nullptr);
+        Trace("nl-cad") << "Found conflict: " << lem << std::endl;
+        return;
+    }
+    if (Trace.isOn("nl-cad"))
+    {
+      Trace("nl-cad") << "After simplifications" << std::endl;
+      Trace("nl-cad") << "* Assertions: " << std::endl;
+      for (const Node& a : processed)
+      {
+        Trace("nl-cad") << "  " << a << std::endl;
+      }
+    }
+    d_CAC.reset();
+    for (const Node& a : processed)
+    {
+      Assert(!a.isConst());
+      d_CAC.getConstraints().addConstraint(a);
+    }
+  }
+  else
+  {
+    d_CAC.reset();
+    for (const Node& a : assertions)
+    {
+      Assert(!a.isConst());
+      d_CAC.getConstraints().addConstraint(a);
+    }
   }
   d_CAC.computeVariableOrdering();
   d_CAC.retrieveInitialAssignment(d_model, d_ranVariable);
@@ -84,6 +117,7 @@ void CadSolver::checkFull()
 {
 #ifdef CVC5_POLY_IMP
   if (d_CAC.getConstraints().getConstraints().empty()) {
+    d_foundSatisfiability = true;
     Trace("nl-cad") << "No constraints. Return." << std::endl;
     return;
   }
@@ -101,6 +135,8 @@ void CadSolver::checkFull()
     Trace("nl-cad") << "Collected MIS: " << mis << std::endl;
     Assert(!mis.empty()) << "Infeasible subset can not be empty";
     Trace("nl-cad") << "UNSAT with MIS: " << mis << std::endl;
+    d_eqsubs.postprocessConflict(mis);
+    Trace("nl-cad") << "After postprocessing: " << mis << std::endl;
     Node lem = NodeManager::currentNM()->mkAnd(mis).negate();
     ProofGenerator* proof = d_CAC.closeProof(mis);
     d_im.addPendingLemma(lem, InferenceId::ARITH_NL_CAD_CONFLICT, proof);
@@ -170,10 +206,15 @@ bool CadSolver::constructModelIfAvailable(std::vector<Node>& assertions)
     return false;
   }
   bool foundNonVariable = false;
+  for (const auto& sub: d_eqsubs.getSubstitutions())
+  {
+    d_model.addSubstitution(sub.first, sub.second);
+    Trace("nl-cad") << "-> " << sub.first << " = " << sub.second << std::endl;
+  }
   for (const auto& v : d_CAC.getVariableOrdering())
   {
     Node variable = d_CAC.getConstraints().varMapper()(v);
-    if (!variable.isVar())
+    if (!Theory::isLeafOf(variable, TheoryId::THEORY_ARITH))
     {
       Trace("nl-cad") << "Not a variable: " << variable << std::endl;
       foundNonVariable = true;
index bedffcaa9566e7943482a18b2c55600ed5de95e4..73d09378bd0e5ad0e8f3a782d9196edc9327c35b 100644 (file)
@@ -23,6 +23,7 @@
 #include "smt/env_obj.h"
 #include "theory/arith/nl/cad/cdcac.h"
 #include "theory/arith/nl/cad/proof_checker.h"
+#include "theory/arith/nl/equality_substitution.h"
 
 namespace cvc5 {
 
@@ -104,6 +105,9 @@ class CadSolver: protected EnvObj
   InferenceManager& d_im;
   /** Reference to the non-linear model object */
   NlModel& d_model;
+  /** Utility to eliminate variables from simple equalities before going into
+   * the actual coverings solver */
+  EqualitySubstitution d_eqsubs;
 }; /* class CadSolver */
 
 }  // namespace nl
diff --git a/src/theory/arith/nl/equality_substitution.cpp b/src/theory/arith/nl/equality_substitution.cpp
new file mode 100644 (file)
index 0000000..9b3a79c
--- /dev/null
@@ -0,0 +1,183 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Gereon Kremer, Andrew Reynolds, Andres Noetzli
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Implementation of new non-linear solver.
+ */
+
+#include "theory/arith/nl/equality_substitution.h"
+
+#include "smt/env.h"
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+namespace nl {
+
+EqualitySubstitution::EqualitySubstitution(Env& env)
+    : EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>())
+{
+}
+void EqualitySubstitution::reset()
+{
+  d_substitutions = std::make_unique<SubstitutionMap>();
+  d_conflict.clear();
+  d_conflictMap.clear();
+  d_trackOrigin.clear();
+}
+
+std::vector<Node> EqualitySubstitution::eliminateEqualities(
+    const std::vector<Node>& assertions)
+{
+  Trace("nl-eqs") << "Input:" << std::endl;
+  for (const auto& a : assertions)
+  {
+    Trace("nl-eqs") << "\t" << a << std::endl;
+  }
+  std::set<TNode> tracker;
+  std::vector<Node> asserts = assertions;
+  std::vector<Node> next;
+
+  size_t last_size = 0;
+  while (asserts.size() != last_size)
+  {
+    last_size = asserts.size();
+    // collect all eliminations from original into d_substitutions
+    for (const auto& orig : asserts)
+    {
+      if (orig.getKind() != Kind::EQUAL) continue;
+      tracker.clear();
+      d_substitutions->invalidateCache();
+      Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker);
+      Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o
+                      << std::endl;
+      if (o.getKind() != Kind::EQUAL) continue;
+      Assert(o.getNumChildren() == 2);
+      for (size_t i = 0; i < 2; ++i)
+      {
+        const auto& l = o[i];
+        const auto& r = o[1 - i];
+        if (l.isConst()) continue;
+        if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue;
+        if (d_substitutions->hasSubstitution(l)) continue;
+        if (expr::hasSubterm(r, l, true)) continue;
+        Trace("nl-eqs") << "Found substitution " << l << " -> " << r
+                        << std::endl
+                        << " from " << o << " / " << orig << std::endl;
+        d_substitutions->addSubstitution(l, r);
+        d_trackOrigin.emplace(l, o);
+        if (o != orig)
+        {
+          addToConflictMap(o, orig, tracker);
+        }
+        break;
+      }
+    }
+
+    // simplify with subs from original into next
+    next.clear();
+    for (const auto& a : asserts)
+    {
+      tracker.clear();
+      d_substitutions->invalidateCache();
+      Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker);
+      if (simp.isConst())
+      {
+        if (simp.getConst<bool>())
+        {
+          continue;
+        }
+        Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl;
+        for (TNode t : tracker)
+        {
+          Trace("nl-eqs") << "Tracker has " << t << std::endl;
+          auto toit = d_trackOrigin.find(t);
+          Assert(toit != d_trackOrigin.end());
+          d_conflict.emplace_back(toit->second);
+        }
+        d_conflict.emplace_back(a);
+        postprocessConflict(d_conflict);
+        Trace("nl-eqs") << "Direct conflict: " << d_conflict << std::endl;
+        Trace("nl-eqs") << std::endl
+                        << d_conflict.size() << " vs "
+                        << std::distance(d_substitutions->begin(),
+                                         d_substitutions->end())
+                        << std::endl
+                        << std::endl;
+        return {};
+      }
+      if (simp != a)
+      {
+        Trace("nl-eqs") << "Simplified " << a << " to " << simp << std::endl;
+        addToConflictMap(simp, a, tracker);
+      }
+      next.emplace_back(simp);
+    }
+    asserts = std::move(next);
+  }
+  d_conflict.clear();
+  return asserts;
+}
+void EqualitySubstitution::postprocessConflict(
+    std::vector<Node>& conflict) const
+{
+  Trace("nl-eqs") << "Postprocessing " << conflict << std::endl;
+  std::set<Node> result;
+  for (const auto& c : conflict)
+  {
+    auto it = d_conflictMap.find(c);
+    if (it == d_conflictMap.end())
+    {
+      result.insert(c);
+    }
+    else
+    {
+      Trace("nl-eqs") << "Origin of " << c << ": " << it->second << std::endl;
+      result.insert(it->second.begin(), it->second.end());
+    }
+  }
+  conflict.clear();
+  conflict.insert(conflict.end(), result.begin(), result.end());
+  Trace("nl-eqs") << "-> " << conflict << std::endl;
+}
+void EqualitySubstitution::insertOrigins(std::set<Node>& dest,
+                                         const Node& n) const
+{
+  auto it = d_conflictMap.find(n);
+  if (it == d_conflictMap.end())
+  {
+    dest.insert(n);
+  }
+  else
+  {
+    dest.insert(it->second.begin(), it->second.end());
+  }
+}
+void EqualitySubstitution::addToConflictMap(const Node& n,
+                                            const Node& orig,
+                                            const std::set<TNode>& tracker)
+{
+  std::set<Node> origins;
+  insertOrigins(origins, orig);
+  for (const auto& t : tracker)
+  {
+    auto tit = d_trackOrigin.find(t);
+    Assert(tit != d_trackOrigin.end());
+    insertOrigins(origins, tit->second);
+  }
+  Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl;
+  d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end()));
+}
+
+}  // namespace nl
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/arith/nl/equality_substitution.h b/src/theory/arith/nl/equality_substitution.h
new file mode 100644 (file)
index 0000000..b095af8
--- /dev/null
@@ -0,0 +1,102 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Gereon Kremer
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * CAD-based solver based on https://arxiv.org/pdf/2003.05633.pdf.
+ */
+
+#ifndef CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H
+#define CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H
+
+#include <vector>
+
+#include "context/context.h"
+#include "expr/node.h"
+#include "expr/node_algorithm.h"
+#include "smt/env_obj.h"
+#include "theory/substitutions.h"
+#include "theory/theory.h"
+
+namespace cvc5 {
+namespace theory {
+namespace arith {
+namespace nl {
+
+/**
+ * This class is a general utility to eliminate variables from a set of
+ * assertions.
+ */
+class EqualitySubstitution : protected EnvObj
+{
+ public:
+  EqualitySubstitution(Env& env);
+  /** Reset this object */
+  void reset();
+
+  /**
+   * Eliminate variables using equalities from the set of assertions.
+   * Returns a set of assertions where some variables may have been eliminated.
+   * Substitutions for the eliminated variables can be obtained from
+   * getSubstitutions().
+   */
+  std::vector<Node> eliminateEqualities(const std::vector<Node>& assertions);
+  /**
+   * Can be called after eliminateEqualities(). Returns the substitutions that
+   * were found and eliminated.
+   */
+  const SubstitutionMap& getSubstitutions() const { return *d_substitutions; }
+  /**
+   * Can be called after eliminateEqualities(). Checks whether a direct conflict
+   * was found, that is an assertion simplified to false during
+   * eliminateEqualities().
+   */
+  bool hasConflict() const { return !d_conflict.empty(); }
+  /**
+   * Return the conflict found in eliminateEqualities() as a set of assertions
+   * that is a subset of the input assertions provided to eliminateEqualities().
+   */
+  const std::vector<Node>& getConflict() const { return d_conflict; }
+  /**
+   * Postprocess a conflict found in the result of eliminateEqualities.
+   * Replaces assertions within the conflict by their origins, i.e. the input
+   * assertions and the assertions that gave rise to the substitutions being
+   * used.
+   */
+  void postprocessConflict(std::vector<Node>& conflict) const;
+
+ private:
+  /** Utility method for addToConflictMap. Checks for n in d_conflictMap */
+  void insertOrigins(std::set<Node>& dest, const Node& n) const;
+  /** Add n -> { orig, *tracker } to the conflict map. The tracked nodes are
+   * first resolved using d_trackOrigin, and everything is run through
+   * insertOrigins to make sure that all origins are input assertions. */
+  void addToConflictMap(const Node& n,
+                        const Node& orig,
+                        const std::set<TNode>& tracker);
+
+  // The SubstitutionMap
+  std::unique_ptr<SubstitutionMap> d_substitutions;
+  // conflicting assertions, if a conflict was found
+  std::vector<Node> d_conflict;
+  // Maps a simplified assertion to the original assertion + set of original
+  // assertions used for substitutions
+  std::map<Node, std::vector<Node>> d_conflictMap;
+  // Maps substituted terms (what will end up in the tracker) to the equality
+  // from which the substitution was derived.
+  std::map<Node, Node> d_trackOrigin;
+};
+
+}  // namespace nl
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__ARITH__NL__EQUALITY_SUBSTITUTION_H */
index 77bb164a965b13aeeb2a9b1233046944ef3d051c..3f60f859649b958510996914172295ff0448b806 100644 (file)
@@ -353,45 +353,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS
   }
 
   // compute whether shared terms have correct values
-  unsigned num_shared_wrong_value = 0;
-  std::vector<Node> shared_term_value_splits;
-  // must ensure that shared terms are equal to their concrete value
-  Trace("nl-ext-mv") << "Shared terms : " << std::endl;
-  for (context::CDList<TNode>::const_iterator its =
-           d_containing.shared_terms_begin();
-       its != d_containing.shared_terms_end();
-       ++its)
-  {
-    TNode shared_term = *its;
-    // compute its value in the model, and its evaluation in the model
-    Node stv0 = d_model.computeConcreteModelValue(shared_term);
-    Node stv1 = d_model.computeAbstractModelValue(shared_term);
-    d_model.printModelValue("nl-ext-mv", shared_term);
-    if (stv0 != stv1)
-    {
-      num_shared_wrong_value++;
-      Trace("nl-ext-mv") << "Bad shared term value : " << shared_term
-                         << std::endl;
-      if (shared_term != stv0)
-      {
-        // split on the value, this is non-terminating in general, TODO :
-        // improve this
-        Node eq = shared_term.eqNode(stv0);
-        shared_term_value_splits.push_back(eq);
-      }
-      else
-      {
-        // this can happen for transcendental functions
-        // the problem is that we cannot evaluate transcendental functions
-        // (they don't have a rewriter that returns constants)
-        // thus, the actual value in their model can be themselves, hence we
-        // have no reference point to rule out the current model.  In this
-        // case, we may set incomplete below.
-      }
-    }
-  }
-  Trace("nl-ext-debug") << "     " << num_shared_wrong_value
-                        << " shared terms with wrong model value." << std::endl;
   bool needsRecheck;
   do
   {
@@ -402,9 +363,9 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS
     int complete_status = 1;
     // We require a check either if an assertion is false or a shared term has
     // a wrong value
-    if (!false_asserts.empty() || num_shared_wrong_value > 0)
+    if (!false_asserts.empty())
     {
-      complete_status = num_shared_wrong_value > 0 ? -1 : 0;
+      complete_status = 0;
       runStrategy(Theory::Effort::EFFORT_FULL, assertions, false_asserts, xts);
       if (d_im.hasSentLemma() || d_im.hasPendingLemma())
       {
@@ -446,40 +407,6 @@ Result::Sat NonlinearExtension::modelBasedRefinement(const std::set<Node>& termS
                         << std::endl;
         return Result::Sat::UNSAT;
       }
-      // resort to splitting on shared terms with their model value
-      // if we did not add any lemmas
-      if (num_shared_wrong_value > 0)
-      {
-        complete_status = -1;
-        if (!shared_term_value_splits.empty())
-        {
-          for (const Node& eq : shared_term_value_splits)
-          {
-            Node req = rewrite(eq);
-            Node literal = d_containing.getValuation().ensureLiteral(req);
-            d_containing.getOutputChannel().requirePhase(literal, true);
-            Trace("nl-ext-debug") << "Split on : " << literal << std::endl;
-            Node split = literal.orNode(literal.negate());
-            d_im.addPendingLemma(split,
-                                 InferenceId::ARITH_NL_SHARED_TERM_VALUE_SPLIT,
-                                 nullptr,
-                                 true);
-          }
-          if (d_im.hasWaitingLemma())
-          {
-            d_im.flushWaitingLemmas();
-            Trace("nl-ext") << "...added " << d_im.numPendingLemmas()
-                            << " shared term value split lemmas." << std::endl;
-            return Result::Sat::UNSAT;
-          }
-        }
-        else
-        {
-          // this can happen if we are trying to do theory combination with
-          // trancendental functions
-          // since their model value cannot even be computed exactly
-        }
-      }
 
       // we are incomplete
       if (options().arith.nlExt == options::NlExtMode::FULL
index b33e45129a3d7b0e2e80cbdd94b927d9abfd0a3d..a14841f67efc3edf71f59b8991bebeffc7b7061c 100644 (file)
@@ -172,10 +172,7 @@ void Strategy::initializeStrategy(const Options& options)
   one << InferStep::POW2_FULL << InferStep::BREAK;
   if (options.arith.nlCad)
   {
-    one << InferStep::CAD_INIT;
-  }
-  if (options.arith.nlCad)
-  {
+    one << InferStep::CAD_INIT << InferStep::BREAK;
     one << InferStep::CAD_FULL << InferStep::BREAK;
   }
 
index f910944817b2f9ae72adddc334407b7ed9649115..e495630467681480f684734f48d07dfa2f337915 100644 (file)
@@ -39,7 +39,7 @@ struct substitution_stack_element {
   }
 };/* struct substitution_stack_element */
 
-Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) {
+Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker) {
 
   Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << ")" << endl;
 
@@ -70,10 +70,17 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) {
     if (find2 != d_substitutions.end()) {
       Node rhs = (*find2).second;
       Assert(rhs != current);
-      internalSubstitute(rhs, cache);
-      d_substitutions[current] = cache[rhs];
+      internalSubstitute(rhs, cache, tracker);
+      if (tracker == nullptr)
+      {
+        d_substitutions[current] = cache[rhs];
+      }
       cache[current] = cache[rhs];
       toVisit.pop_back();
+      if (tracker != nullptr)
+      {
+        tracker->insert(current);
+      }
       continue;
     }
 
@@ -101,10 +108,14 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache) {
           if (find2 != d_substitutions.end()) {
             Node rhs = (*find2).second;
             Assert(rhs != result);
-            internalSubstitute(rhs, cache);
+            internalSubstitute(rhs, cache, tracker);
             d_substitutions[result] = cache[rhs];
             cache[result] = cache[rhs];
             result = cache[rhs];
+            if (tracker != nullptr)
+            {
+              tracker->insert(result);
+            }
           }
         }
       }
@@ -184,8 +195,8 @@ void SubstitutionMap::addSubstitutions(SubstitutionMap& subMap, bool invalidateC
   }
 }
 
-Node SubstitutionMap::apply(TNode t, Rewriter* r)
-{
+Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set<TNode>* tracker) {
+
   Debug("substitution") << "SubstitutionMap::apply(" << t << ")" << endl;
 
   // Setup the cache
@@ -196,7 +207,7 @@ Node SubstitutionMap::apply(TNode t, Rewriter* r)
   }
 
   // Perform the substitution
-  Node result = internalSubstitute(t, d_substitutionCache);
+  Node result = internalSubstitute(t, d_substitutionCache, tracker);
   Debug("substitution") << "SubstitutionMap::apply(" << t << ") => " << result << endl;
 
   if (r != nullptr)
index 7a3afcb1188ad18c379bab767ca6b02a9aec3864..2154c7fd53403506112e80a95bc6ffcdb9c087f6 100644 (file)
@@ -65,7 +65,7 @@ class SubstitutionMap
   bool d_cacheInvalidated;
 
   /** Internal method that performs substitution */
-  Node internalSubstitute(TNode t, NodeCache& cache);
+  Node internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker);
 
   /** Helper class to invalidate cache on user pop */
   class CacheInvalidator : public context::ContextNotifyObj
@@ -130,7 +130,7 @@ class SubstitutionMap
    * Apply the substitutions to the node, optionally rewrite if a non-null
    * Rewriter pointer is passed.
    */
-  Node apply(TNode t, Rewriter* r = nullptr);
+  Node apply(TNode t, Rewriter* r = nullptr, std::set<TNode>* tracker = nullptr);
 
   /**
    * Apply the substitutions to the node.
@@ -155,6 +155,10 @@ class SubstitutionMap
    */
   void print(std::ostream& out) const;
 
+  void invalidateCache() {
+    d_cacheInvalidated = true;
+  }
+
 }; /* class SubstitutionMap */
 
 inline std::ostream& operator << (std::ostream& out, const SubstitutionMap& subst) {
index ccb50c55da0964c4877f14c898299bd7334abf14..b49287c30f21a8080f8a82393071a64e5031cdb1 100644 (file)
@@ -1,6 +1,6 @@
 ; REQUIRES: poly
 ; COMMAND-LINE: --theoryof-mode=term --nl-icp
-; EXPECT: unknown
+; EXPECT: sat
 (set-logic QF_NRA)
 (set-option :check-proofs true)
 (declare-fun x () Real)
index bedc0209ba1d44eb0d1480ca120225d0e94fe596..ba49f23fe84c3ffc2014426f9227bd4d8df63087 100644 (file)
@@ -1,5 +1,5 @@
 ; COMMAND-LINE: --nl-ext=full --no-nl-ext-tf-tplanes --no-nl-ext-inc-prec
-; EXPECT: unknown
+; EXPECT: sat
 (set-logic UFNRAT)
 (declare-fun f (Real) Real)