The substitutions we extract from equalities in the nonlinear solver would sometimes generate cyclic substitutions.
This PR tries harder to avoid such cases.
Fixes cvc5/cvc5-projects#444.
#include "theory/arith/nl/equality_substitution.h"
#include "smt/env.h"
+#include "theory/arith/arith_utilities.h"
namespace cvc5 {
namespace theory {
namespace arith {
namespace nl {
+namespace {
+struct ShouldTraverse : public SubstitutionMap::ShouldTraverseCallback
+{
+ bool operator()(TNode n) const override
+ {
+ switch (theory::kindToTheoryId(n.getKind()))
+ {
+ case TheoryId::THEORY_BOOL:
+ case TheoryId::THEORY_BUILTIN: return true;
+ case TheoryId::THEORY_ARITH: return !isTranscendentalKind(n.getKind());
+ default: return false;
+ }
+ }
+};
+} // namespace
+
EqualitySubstitution::EqualitySubstitution(Env& env)
: EnvObj(env), d_substitutions(std::make_unique<SubstitutionMap>())
{
std::vector<Node> EqualitySubstitution::eliminateEqualities(
const std::vector<Node>& assertions)
{
- Trace("nl-eqs") << "Input:" << std::endl;
- for (const auto& a : assertions)
+ if (Trace.isOn("nl-eqs"))
{
- Trace("nl-eqs") << "\t" << a << std::endl;
+ Trace("nl-eqs") << "Input:" << std::endl;
+ for (const auto& a : assertions)
+ {
+ Trace("nl-eqs") << "\t" << a << std::endl;
+ }
}
std::set<TNode> tracker;
std::vector<Node> asserts = assertions;
std::vector<Node> next;
+ const ShouldTraverse stc;
size_t last_size = 0;
while (asserts.size() != last_size)
if (orig.getKind() != Kind::EQUAL) continue;
tracker.clear();
d_substitutions->invalidateCache();
- Node o = d_substitutions->apply(orig, d_env.getRewriter(), &tracker);
- Trace("nl-eqs") << "Simplified for subst " << orig << " -> " << o
- << std::endl;
+ Node o =
+ d_substitutions->apply(orig, d_env.getRewriter(), &tracker, &stc);
if (o.getKind() != Kind::EQUAL) continue;
Assert(o.getNumChildren() == 2);
for (size_t i = 0; i < 2; ++i)
if (l.isConst()) continue;
if (!Theory::isLeafOf(l, TheoryId::THEORY_ARITH)) continue;
if (d_substitutions->hasSubstitution(l)) continue;
- if (expr::hasSubterm(r, l, true)) continue;
+ if (expr::hasSubterm(r, l)) continue;
+ d_substitutions->invalidateCache();
+ if (expr::hasSubterm(d_substitutions->apply(r), l)) continue;
Trace("nl-eqs") << "Found substitution " << l << " -> " << r
<< std::endl
<< " from " << o << " / " << orig << std::endl;
{
tracker.clear();
d_substitutions->invalidateCache();
- Node simp = d_substitutions->apply(a, d_env.getRewriter(), &tracker);
+ Node simp =
+ d_substitutions->apply(a, d_env.getRewriter(), &tracker, &stc);
if (simp.isConst())
{
if (simp.getConst<bool>())
asserts = std::move(next);
}
d_conflict.clear();
+ if (Trace.isOn("nl-eqs"))
+ {
+ Trace("nl-eqs") << "Output:" << std::endl;
+ for (const auto& a : asserts)
+ {
+ Trace("nl-eqs") << "\t" << a << std::endl;
+ }
+ Trace("nl-eqs") << "Substitutions:" << std::endl;
+ for (const auto& subs : d_substitutions->getSubstitutions())
+ {
+ Trace("nl-eqs") << "\t" << subs.first << " -> " << subs.second
+ << std::endl;
+ }
+ }
return asserts;
}
void EqualitySubstitution::postprocessConflict(
Assert(tit != d_trackOrigin.end());
insertOrigins(origins, tit->second);
}
- Trace("nl-eqs") << "ConflictMap: " << n << " -> " << origins << std::endl;
d_conflictMap.emplace(n, std::vector<Node>(origins.begin(), origins.end()));
}
}
};/* struct substitution_stack_element */
-Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker) {
-
+Node SubstitutionMap::internalSubstitute(TNode t,
+ NodeCache& cache,
+ std::set<TNode>* tracker,
+ const ShouldTraverseCallback* stc)
+{
Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << ")" << endl;
if (d_substitutions.empty()) {
if (find2 != d_substitutions.end()) {
Node rhs = (*find2).second;
Assert(rhs != current);
- internalSubstitute(rhs, cache, tracker);
+ internalSubstitute(rhs, cache, tracker, stc);
if (tracker == nullptr)
{
d_substitutions[current] = cache[rhs];
if (find2 != d_substitutions.end()) {
Node rhs = (*find2).second;
Assert(rhs != result);
- internalSubstitute(rhs, cache, tracker);
+ internalSubstitute(rhs, cache, tracker, stc);
d_substitutions[result] = cache[rhs];
cache[result] = cache[rhs];
if (tracker != nullptr)
else
{
// Mark that we have added the children if any
- if (current.getNumChildren() > 0 || current.getMetaKind() == kind::metakind::PARAMETERIZED) {
+ bool recurse = (stc == nullptr || (*stc)(current));
+ if (recurse
+ && (current.getNumChildren() > 0
+ || current.getMetaKind() == kind::metakind::PARAMETERIZED))
+ {
stackHead.d_children_added = true;
// We need to add the operator, if any
if(current.getMetaKind() == kind::metakind::PARAMETERIZED) {
toVisit.push_back(childNode);
}
}
- } else {
+ }
+ else
+ {
// No children, so we're done
Debug("substitution::internal") << "SubstitutionMap::internalSubstitute(" << t << "): setting " << current << " -> " << current << endl;
cache[current] = current;
// Return the substituted version
return cache[t];
-}/* SubstitutionMap::internalSubstitute() */
-
+} /* SubstitutionMap::internalSubstitute() */
void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache)
{
}
}
-Node SubstitutionMap::apply(TNode t, Rewriter* r, std::set<TNode>* tracker) {
-
+Node SubstitutionMap::apply(TNode t,
+ Rewriter* r,
+ std::set<TNode>* tracker,
+ const ShouldTraverseCallback* stc)
+{
Debug("substitution") << "SubstitutionMap::apply(" << t << ")" << endl;
// Setup the cache
}
// Perform the substitution
- Node result = internalSubstitute(t, d_substitutionCache, tracker);
+ Node result = internalSubstitute(t, d_substitutionCache, tracker, stc);
Debug("substitution") << "SubstitutionMap::apply(" << t << ") => " << result << endl;
if (r != nullptr)
typedef NodeMap::iterator iterator;
typedef NodeMap::const_iterator const_iterator;
+ struct ShouldTraverseCallback
+ {
+ virtual bool operator()(TNode n) const = 0;
+ virtual ~ShouldTraverseCallback() {}
+ };
+
private:
typedef std::unordered_map<Node, Node> NodeCache;
/** A dummy context used by this class if none is provided */
bool d_cacheInvalidated;
/** Internal method that performs substitution */
- Node internalSubstitute(TNode t, NodeCache& cache, std::set<TNode>* tracker);
+ Node internalSubstitute(TNode t,
+ NodeCache& cache,
+ std::set<TNode>* tracker,
+ const ShouldTraverseCallback* stc);
/** Helper class to invalidate cache on user pop */
class CacheInvalidator : public context::ContextNotifyObj
* Apply the substitutions to the node, optionally rewrite if a non-null
* Rewriter pointer is passed.
*/
- Node apply(TNode t, Rewriter* r = nullptr, std::set<TNode>* tracker = nullptr);
+ Node apply(TNode t,
+ Rewriter* r = nullptr,
+ std::set<TNode>* tracker = nullptr,
+ const ShouldTraverseCallback* stc = nullptr);
/**
* Apply the substitutions to the node.
regress0/nl/pow2-pow.smt2
regress0/nl/pow2-pow-isabelle.smt2
regress0/nl/proj-issue-348.smt2
+ regress0/nl/proj-issue-444-memout-eqelim.smt2
regress0/nl/real-as-int.smt2
regress0/nl/real-div-ufnra.smt2
regress0/nl/sin-cos-346-b-chunk-0169.smt2
--- /dev/null
+; REQUIRES: poly
+; EXPECT: sat
+(set-logic QF_UFNRA)
+(declare-fun w (Real) Real)
+(declare-fun m (Real) Real)
+(declare-fun t (Real) Bool)
+(declare-fun u (Real) Real)
+(assert (= (m 1) (w 0)))
+(assert (not (t 0.0)))
+(assert (= (+ 1 (w 1)) (* (u 1.0) (m (+ 1 (w 1))))))
+(assert (= (t 0) (= (w 1) (* (u 1) (u 0)))))
+(check-sat)