From 9617530df28b3d0ae75da349e956ca427cf02c75 Mon Sep 17 00:00:00 2001 From: Gereon Kremer Date: Wed, 2 Feb 2022 10:51:20 -0800 Subject: [PATCH] Add additional check to avoid cyclic substitution (#7991) The substitutions we extract from equalities in the nonlinear solver would sometimes generate cyclic substitutions. This PR tries harder to avoid such cases. Fixes cvc5/cvc5-projects#444. --- src/theory/arith/nl/equality_substitution.cpp | 54 +++++++++++++++---- src/theory/substitutions.cpp | 33 ++++++++---- src/theory/substitutions.h | 16 +++++- test/regress/CMakeLists.txt | 1 + .../nl/proj-issue-444-memout-eqelim.smt2 | 12 +++++ 5 files changed, 94 insertions(+), 22 deletions(-) create mode 100644 test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 diff --git a/src/theory/arith/nl/equality_substitution.cpp b/src/theory/arith/nl/equality_substitution.cpp index 9b3a79cd4..720ba7478 100644 --- a/src/theory/arith/nl/equality_substitution.cpp +++ b/src/theory/arith/nl/equality_substitution.cpp @@ -16,12 +16,29 @@ #include "theory/arith/nl/equality_substitution.h" #include "smt/env.h" +#include "theory/arith/arith_utilities.h" namespace cvc5 { namespace theory { namespace arith { namespace nl { +namespace { +struct ShouldTraverse : public SubstitutionMap::ShouldTraverseCallback +{ + bool operator()(TNode n) const override + { + switch (theory::kindToTheoryId(n.getKind())) + { + case TheoryId::THEORY_BOOL: + case TheoryId::THEORY_BUILTIN: return true; + case TheoryId::THEORY_ARITH: return !isTranscendentalKind(n.getKind()); + default: return false; + } + } +}; +} // namespace + EqualitySubstitution::EqualitySubstitution(Env& env) : EnvObj(env), d_substitutions(std::make_unique()) { @@ -37,14 +54,18 @@ void EqualitySubstitution::reset() std::vector EqualitySubstitution::eliminateEqualities( const std::vector& assertions) { - Trace("nl-eqs") << "Input:" << std::endl; - for (const auto& a : assertions) + if (Trace.isOn("nl-eqs")) { - Trace("nl-eqs") << "\t" << a << std::endl; + Trace("nl-eqs") << "Input:" << std::endl; + for (const auto& a : assertions) + { + Trace("nl-eqs") << "\t" << a << std::endl; + } } std::set tracker; std::vector asserts = assertions; std::vector next; + const ShouldTraverse stc; size_t last_size = 0; while (asserts.size() != last_size) @@ -56,9 +77,8 @@ std::vector EqualitySubstitution::eliminateEqualities( if (orig.getKind() != Kind::EQUAL) continue; tracker.clear(); d_substitutions->invalidateCache(); - Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker); - Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o - << std::endl; + Node o = + d_substitutions->apply(orig, d_env.getRewriter(), &tracker, &stc); if (o.getKind() != Kind::EQUAL) continue; Assert(o.getNumChildren() == 2); for (size_t i = 0; i < 2; ++i) @@ -68,7 +88,9 @@ std::vector EqualitySubstitution::eliminateEqualities( if (l.isConst()) continue; if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue; if (d_substitutions->hasSubstitution(l)) continue; - if (expr::hasSubterm(r, l, true)) continue; + if (expr::hasSubterm(r, l)) continue; + d_substitutions->invalidateCache(); + if (expr::hasSubterm(d_substitutions->apply(r), l)) continue; Trace("nl-eqs") << "Found substitution " << l << " -> " << r << std::endl << " from " << o << " / " << orig << std::endl; @@ -88,7 +110,8 @@ std::vector EqualitySubstitution::eliminateEqualities( { tracker.clear(); d_substitutions->invalidateCache(); - Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker); + Node simp = + d_substitutions->apply(a, d_env.getRewriter(), &tracker, &stc); if (simp.isConst()) { if (simp.getConst()) @@ -124,6 +147,20 @@ std::vector EqualitySubstitution::eliminateEqualities( asserts = std::move(next); } d_conflict.clear(); + if (Trace.isOn("nl-eqs")) + { + Trace("nl-eqs") << "Output:" << std::endl; + for (const auto& a : asserts) + { + Trace("nl-eqs") << "\t" << a << std::endl; + } + Trace("nl-eqs") << "Substitutions:" << std::endl; + for (const auto& subs : d_substitutions->getSubstitutions()) + { + Trace("nl-eqs") << "\t" << subs.first << " -> " << subs.second + << std::endl; + } + } return asserts; } void EqualitySubstitution::postprocessConflict( @@ -173,7 +210,6 @@ void EqualitySubstitution::addToConflictMap(const Node& n, Assert(tit != d_trackOrigin.end()); insertOrigins(origins, tit->second); } - Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl; d_conflictMap.emplace(n, std::vector(origins.begin(), origins.end())); } diff --git a/src/theory/substitutions.cpp b/src/theory/substitutions.cpp index 71612021e..4e1e219d4 100644 --- a/src/theory/substitutions.cpp +++ b/src/theory/substitutions.cpp @@ -49,8 +49,11 @@ struct substitution_stack_element { } };/* struct substitution_stack_element */ -Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set* tracker) { - +Node SubstitutionMap::internalSubstitute(TNode t, + NodeCache& cache, + std::set* tracker, + const ShouldTraverseCallback* stc) +{ Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << ")" << endl; if (d_substitutions.empty()) { @@ -80,7 +83,7 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set 0 || current.getMetaKind() == kind::metakind::PARAMETERIZED) { + bool recurse = (stc == nullptr || (*stc)(current)); + if (recurse + && (current.getNumChildren() > 0 + || current.getMetaKind() == kind::metakind::PARAMETERIZED)) + { stackHead.d_children_added = true; // We need to add the operator, if any if(current.getMetaKind() == kind::metakind::PARAMETERIZED) { @@ -154,7 +161,9 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set " << current << endl; cache[current] = current; @@ -165,8 +174,7 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set* tracker) { - +Node SubstitutionMap::apply(TNode t, + Rewriter* r, + std::set* tracker, + const ShouldTraverseCallback* stc) +{ Debug("substitution") << "SubstitutionMap::apply(" << t << ")" << endl; // Setup the cache @@ -217,7 +228,7 @@ Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set* tracker) { } // Perform the substitution - Node result = internalSubstitute(t, d_substitutionCache, tracker); + Node result = internalSubstitute(t, d_substitutionCache, tracker, stc); Debug("substitution") << "SubstitutionMap::apply(" << t << ") => " << result << endl; if (r != nullptr) diff --git a/src/theory/substitutions.h b/src/theory/substitutions.h index f6d4bdcf0..1029c7a0a 100644 --- a/src/theory/substitutions.h +++ b/src/theory/substitutions.h @@ -50,6 +50,12 @@ class SubstitutionMap typedef NodeMap::iterator iterator; typedef NodeMap::const_iterator const_iterator; + struct ShouldTraverseCallback + { + virtual bool operator()(TNode n) const = 0; + virtual ~ShouldTraverseCallback() {} + }; + private: typedef std::unordered_map NodeCache; /** A dummy context used by this class if none is provided */ @@ -65,7 +71,10 @@ class SubstitutionMap bool d_cacheInvalidated; /** Internal method that performs substitution */ - Node internalSubstitute(TNode t, NodeCache& cache, std::set* tracker); + Node internalSubstitute(TNode t, + NodeCache& cache, + std::set* tracker, + const ShouldTraverseCallback* stc); /** Helper class to invalidate cache on user pop */ class CacheInvalidator : public context::ContextNotifyObj @@ -134,7 +143,10 @@ class SubstitutionMap * Apply the substitutions to the node, optionally rewrite if a non-null * Rewriter pointer is passed. */ - Node apply(TNode t, Rewriter* r = nullptr, std::set* tracker = nullptr); + Node apply(TNode t, + Rewriter* r = nullptr, + std::set* tracker = nullptr, + const ShouldTraverseCallback* stc = nullptr); /** * Apply the substitutions to the node. diff --git a/test/regress/CMakeLists.txt b/test/regress/CMakeLists.txt index 6c7862d1e..83cc6f1c8 100644 --- a/test/regress/CMakeLists.txt +++ b/test/regress/CMakeLists.txt @@ -779,6 +779,7 @@ set(regress_0_tests regress0/nl/pow2-pow.smt2 regress0/nl/pow2-pow-isabelle.smt2 regress0/nl/proj-issue-348.smt2 + regress0/nl/proj-issue-444-memout-eqelim.smt2 regress0/nl/real-as-int.smt2 regress0/nl/real-div-ufnra.smt2 regress0/nl/sin-cos-346-b-chunk-0169.smt2 diff --git a/test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 b/test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 new file mode 100644 index 000000000..479dc3905 --- /dev/null +++ b/test/regress/regress0/nl/proj-issue-444-memout-eqelim.smt2 @@ -0,0 +1,12 @@ +; REQUIRES: poly +; EXPECT: sat +(set-logic QF_UFNRA) +(declare-fun w (Real) Real) +(declare-fun m (Real) Real) +(declare-fun t (Real) Bool) +(declare-fun u (Real) Real) +(assert (= (m 1) (w 0))) +(assert (not (t 0.0))) +(assert (= (+ 1 (w 1)) (* (u 1.0) (m (+ 1 (w 1)))))) +(assert (= (t 0) (= (w 1) (* (u 1) (u 0))))) +(check-sat) -- 2.30.2