Add additional check to avoid cyclic substitution (#7991)
authorGereon Kremer <gkremer@stanford.edu>
Wed, 2 Feb 2022 18:51:20 +0000 (10:51 -0800)
committerGitHub <noreply@github.com>
Wed, 2 Feb 2022 18:51:20 +0000 (18:51 +0000)
The substitutions we extract from equalities in the nonlinear solver would sometimes generate cyclic substitutions.
This PR tries harder to avoid such cases.
Fixes cvc5/cvc5-projects#444.

src/theory/arith/nl/equality_substitution.cpp
src/theory/substitutions.cpp
src/theory/substitutions.h
test/regress/CMakeLists.txt
test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 [new file with mode: 0644]

index 9b3a79cd4c4e7316af7e8778d068af5c55dd80d4..720ba7478398dd3d1e69e3c9b80eee08ee3e0346 100644 (file)
 #include "theory/arith/nl/equality_substitution.h"
 
 #include "smt/env.h"
+#include "theory/arith/arith_utilities.h"
 
 namespace cvc5 {
 namespace theory {
 namespace arith {
 namespace nl {
 
+namespace {
+struct ShouldTraverse : public SubstitutionMap::ShouldTraverseCallback
+{
+  bool operator()(TNode n) const override
+  {
+    switch (theory::kindToTheoryId(n.getKind()))
+    {
+      case TheoryId::THEORY_BOOL:
+      case TheoryId::THEORY_BUILTIN: return true;
+      case TheoryId::THEORY_ARITH: return !isTranscendentalKind(n.getKind());
+      default: return false;
+    }
+  }
+};
+}  // namespace
+
 EqualitySubstitution::EqualitySubstitution(Env& env)
     : EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>())
 {
@@ -37,14 +54,18 @@ void EqualitySubstitution::reset()
 std::vector<Node> EqualitySubstitution::eliminateEqualities(
     const std::vector<Node>& assertions)
 {
-  Trace("nl-eqs") << "Input:" << std::endl;
-  for (const auto& a : assertions)
+  if (Trace.isOn("nl-eqs"))
   {
-    Trace("nl-eqs") << "\t" << a << std::endl;
+    Trace("nl-eqs") << "Input:" << std::endl;
+    for (const auto& a : assertions)
+    {
+      Trace("nl-eqs") << "\t" << a << std::endl;
+    }
   }
   std::set<TNode> tracker;
   std::vector<Node> asserts = assertions;
   std::vector<Node> next;
+  const ShouldTraverse stc;
 
   size_t last_size = 0;
   while (asserts.size() != last_size)
@@ -56,9 +77,8 @@ std::vector<Node> EqualitySubstitution::eliminateEqualities(
       if (orig.getKind() != Kind::EQUAL) continue;
       tracker.clear();
       d_substitutions->invalidateCache();
-      Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker);
-      Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o
-                      << std::endl;
+      Node o =
+          d_substitutions->apply(orig, d_env.getRewriter(), &tracker, &stc);
       if (o.getKind() != Kind::EQUAL) continue;
       Assert(o.getNumChildren() == 2);
       for (size_t i = 0; i < 2; ++i)
@@ -68,7 +88,9 @@ std::vector<Node> EqualitySubstitution::eliminateEqualities(
         if (l.isConst()) continue;
         if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue;
         if (d_substitutions->hasSubstitution(l)) continue;
-        if (expr::hasSubterm(r, l, true)) continue;
+        if (expr::hasSubterm(r, l)) continue;
+        d_substitutions->invalidateCache();
+        if (expr::hasSubterm(d_substitutions->apply(r), l)) continue;
         Trace("nl-eqs") << "Found substitution " << l << " -> " << r
                         << std::endl
                         << " from " << o << " / " << orig << std::endl;
@@ -88,7 +110,8 @@ std::vector<Node> EqualitySubstitution::eliminateEqualities(
     {
       tracker.clear();
       d_substitutions->invalidateCache();
-      Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker);
+      Node simp =
+          d_substitutions->apply(a, d_env.getRewriter(), &tracker, &stc);
       if (simp.isConst())
       {
         if (simp.getConst<bool>())
@@ -124,6 +147,20 @@ std::vector<Node> EqualitySubstitution::eliminateEqualities(
     asserts = std::move(next);
   }
   d_conflict.clear();
+  if (Trace.isOn("nl-eqs"))
+  {
+    Trace("nl-eqs") << "Output:" << std::endl;
+    for (const auto& a : asserts)
+    {
+      Trace("nl-eqs") << "\t" << a << std::endl;
+    }
+    Trace("nl-eqs") << "Substitutions:" << std::endl;
+    for (const auto& subs : d_substitutions->getSubstitutions())
+    {
+      Trace("nl-eqs") << "\t" << subs.first << " -> " << subs.second
+                      << std::endl;
+    }
+  }
   return asserts;
 }
 void EqualitySubstitution::postprocessConflict(
@@ -173,7 +210,6 @@ void EqualitySubstitution::addToConflictMap(const Node& n,
     Assert(tit != d_trackOrigin.end());
     insertOrigins(origins, tit->second);
   }
-  Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl;
   d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end()));
 }
 
index 71612021e3dd65cd9d2dd857346b8f8086901ca4..4e1e219d44d03f212dcf5568fbccd65ec88a508d 100644 (file)
@@ -49,8 +49,11 @@ struct substitution_stack_element {
   }
 };/* struct substitution_stack_element */
 
-Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker) {
-
+Node SubstitutionMap::internalSubstitute(TNode t,
+                                         NodeCache& cache,
+                                         std::set<TNode>* tracker,
+                                         const ShouldTraverseCallback* stc)
+{
   Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << ")" << endl;
 
   if (d_substitutions.empty()) {
@@ -80,7 +83,7 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNo
     if (find2 != d_substitutions.end()) {
       Node rhs = (*find2).second;
       Assert(rhs != current);
-      internalSubstitute(rhs, cache, tracker);
+      internalSubstitute(rhs, cache, tracker, stc);
       if (tracker == nullptr)
       {
         d_substitutions[current] = cache[rhs];
@@ -118,7 +121,7 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNo
           if (find2 != d_substitutions.end()) {
             Node rhs = (*find2).second;
             Assert(rhs != result);
-            internalSubstitute(rhs, cache, tracker);
+            internalSubstitute(rhs, cache, tracker, stc);
             d_substitutions[result] = cache[rhs];
             cache[result] = cache[rhs];
             if (tracker != nullptr)
@@ -136,7 +139,11 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNo
     else
     {
       // Mark that we have added the children if any
-      if (current.getNumChildren() > 0 || current.getMetaKind() == kind::metakind::PARAMETERIZED) {
+      bool recurse = (stc == nullptr || (*stc)(current));
+      if (recurse
+          && (current.getNumChildren() > 0
+              || current.getMetaKind() == kind::metakind::PARAMETERIZED))
+      {
         stackHead.d_children_added = true;
         // We need to add the operator, if any
         if(current.getMetaKind() == kind::metakind::PARAMETERIZED) {
@@ -154,7 +161,9 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNo
             toVisit.push_back(childNode);
           }
         }
-      } else {
+      }
+      else
+      {
         // No children, so we're done
         Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << "): setting " << current << " -> " << current << endl;
         cache[current] = current;
@@ -165,8 +174,7 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNo
 
   // Return the substituted version
   return cache[t];
-}/* SubstitutionMap::internalSubstitute() */
-
+} /* SubstitutionMap::internalSubstitute() */
 
 void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache)
 {
@@ -205,8 +213,11 @@ void SubstitutionMap::addSubstitutions(SubstitutionMap& subMap, bool invalidateC
   }
 }
 
-Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set<TNode>* tracker) {
-
+Node SubstitutionMap::apply(TNode t,
+                            Rewriter* r,
+                            std::set<TNode>* tracker,
+                            const ShouldTraverseCallback* stc)
+{
   Debug("substitution") << "SubstitutionMap::apply(" << t << ")" << endl;
 
   // Setup the cache
@@ -217,7 +228,7 @@ Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set<TNode>* tracker) {
   }
 
   // Perform the substitution
-  Node result = internalSubstitute(t, d_substitutionCache, tracker);
+  Node result = internalSubstitute(t, d_substitutionCache, tracker, stc);
   Debug("substitution") << "SubstitutionMap::apply(" << t << ") => " << result << endl;
 
   if (r != nullptr)
index f6d4bdcf01ae7af957c75d740708fbc48dc81ff6..1029c7a0ae3d05a70ca02de88be6b18156f67dde 100644 (file)
@@ -50,6 +50,12 @@ class SubstitutionMap
   typedef NodeMap::iterator iterator;
   typedef NodeMap::const_iterator const_iterator;
 
+  struct ShouldTraverseCallback
+  {
+    virtual bool operator()(TNode n) const = 0;
+    virtual ~ShouldTraverseCallback() {}
+  };
+
  private:
   typedef std::unordered_map<Node, Node> NodeCache;
   /** A dummy context used by this class if none is provided */
@@ -65,7 +71,10 @@ class SubstitutionMap
   bool d_cacheInvalidated;
 
   /** Internal method that performs substitution */
-  Node internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker);
+  Node internalSubstitute(TNode t,
+                          NodeCache& cache,
+                          std::set<TNode>* tracker,
+                          const ShouldTraverseCallback* stc);
 
   /** Helper class to invalidate cache on user pop */
   class CacheInvalidator : public context::ContextNotifyObj
@@ -134,7 +143,10 @@ class SubstitutionMap
    * Apply the substitutions to the node, optionally rewrite if a non-null
    * Rewriter pointer is passed.
    */
-  Node apply(TNode t, Rewriter* r = nullptr, std::set<TNode>* tracker = nullptr);
+  Node apply(TNode t,
+             Rewriter* r = nullptr,
+             std::set<TNode>* tracker = nullptr,
+             const ShouldTraverseCallback* stc = nullptr);
 
   /**
    * Apply the substitutions to the node.
index 6c7862d1e03c50565420c519587c95239b4b32f7..83cc6f1c8e0be52919c6d79eaaa5ff0c8f250ec2 100644 (file)
@@ -779,6 +779,7 @@ set(regress_0_tests
   regress0/nl/pow2-pow.smt2
   regress0/nl/pow2-pow-isabelle.smt2
   regress0/nl/proj-issue-348.smt2
+  regress0/nl/proj-issue-444-memout-eqelim.smt2
   regress0/nl/real-as-int.smt2
   regress0/nl/real-div-ufnra.smt2
   regress0/nl/sin-cos-346-b-chunk-0169.smt2
diff --git a/test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 b/test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2
new file mode 100644 (file)
index 0000000..479dc39
--- /dev/null
@@ -0,0 +1,12 @@
+; REQUIRES: poly
+; EXPECT: sat
+(set-logic QF_UFNRA)
+(declare-fun w (Real) Real)
+(declare-fun m (Real) Real)
+(declare-fun t (Real) Bool)
+(declare-fun u (Real) Real)
+(assert (= (m 1) (w 0)))
+(assert (not (t 0.0)))
+(assert (= (+ 1 (w 1)) (* (u 1.0) (m (+ 1 (w 1))))))
+(assert (= (t 0) (= (w 1) (* (u 1) (u 0)))))
+(check-sat)