From 0c6249a1b2177fda94526b66510474f2cb01a411 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Fri, 28 Aug 2020 13:01:55 -0500 Subject: [PATCH] (new theory) Update TheoryFP to the new interface (#4953) This updates the theory of floating points to the new interface (see #4929). Notice that TheoryFP was not adding trigger terms to its equality engine (which should be done during notifySharedTerm), and thus was not propagating equalities between shared terms in combined theories. This PR updates its notifySharedTerm method to the default one. FYI @martin-cs --- src/theory/fp/theory_fp.cpp | 149 ++++++++++++++---------------------- src/theory/fp/theory_fp.h | 32 ++++++-- 2 files changed, 85 insertions(+), 96 deletions(-) diff --git a/src/theory/fp/theory_fp.cpp b/src/theory/fp/theory_fp.cpp index 0c5a92572..4c59b1c06 100644 --- a/src/theory/fp/theory_fp.cpp +++ b/src/theory/fp/theory_fp.cpp @@ -110,7 +110,6 @@ TheoryFp::TheoryFp(context::Context* c, d_registeredTerms(u), d_conv(u), d_expansionRequested(false), - d_conflict(c, false), d_conflictNode(c, Node::null()), d_minMap(u), d_maxMap(u), @@ -908,15 +907,6 @@ void TheoryFp::preRegisterTerm(TNode node) return; } -void TheoryFp::notifySharedTerm(TNode node) -{ - Trace("fp-addSharedTerm") - << "TheoryFp::notifySharedTerm(): " << node << std::endl; - // A system-wide invariant; terms must be registered before they are shared - Assert(isRegistered(node)); - return; -} - void TheoryFp::handleLemma(Node node) { Trace("fp") << "TheoryFp::handleLemma(): asserting " << node << std::endl; // Preprocess has to be true because it contains embedded ITEs @@ -926,83 +916,41 @@ void TheoryFp::handleLemma(Node node) { return; } -bool TheoryFp::handlePropagation(TNode node) { - Trace("fp") << "TheoryFp::handlePropagation(): propagate " << node - << std::endl; +bool TheoryFp::propagateLit(TNode node) +{ + Trace("fp") << "TheoryFp::propagateLit(): propagate " << node << std::endl; bool stat = d_out->propagate(node); if (!stat) { - d_conflict = true; + d_state.notifyInConflict(); } return stat; } -void TheoryFp::handleConflict(TNode node) { - Trace("fp") << "TheoryFp::handleConflict(): conflict detected " << node - << std::endl; +void TheoryFp::conflictEqConstantMerge(TNode t1, TNode t2) +{ + std::vector assumptions; + d_equalityEngine->explainEquality(t1, t2, true, assumptions); + + Node conflict = helper::buildConjunct(assumptions); + Trace("fp") << "TheoryFp::conflictEqConstantMerge(): conflict detected " + << conflict << std::endl; - d_conflictNode = node; - d_conflict = true; - d_out->conflict(node); + d_conflictNode = conflict; + d_state.notifyInConflict(); + d_out->conflict(conflict); return; } -void TheoryFp::check(Effort level) { - Trace("fp") << "TheoryFp::check(): started at effort level " << level - << std::endl; - - while (!done() && !d_conflict) { - // Get all the assertions - Assertion assertion = get(); - TNode fact = assertion.d_assertion; - - Debug("fp") << "TheoryFp::check(): processing " << fact << std::endl; - - // Only handle equalities; the rest should be handled by - // the bit-vector theory - - bool negated = fact.getKind() == kind::NOT; - TNode predicate = negated ? fact[0] : fact; - - if (predicate.getKind() == kind::EQUAL) { - Assert(!(predicate[0].getType().isFloatingPoint() - || predicate[0].getType().isRoundingMode()) - || isRegistered(predicate[0])); - Assert(!(predicate[1].getType().isFloatingPoint() - || predicate[1].getType().isRoundingMode()) - || isRegistered(predicate[1])); - registerTerm(predicate); // Needed for float equalities - - if (negated) { - Debug("fp-eq") << "TheoryFp::check(): adding dis-equality " << fact[0] - << std::endl; - d_equalityEngine->assertEquality(predicate, false, fact); - } else { - Debug("fp-eq") << "TheoryFp::check(): adding equality " << fact - << std::endl; - d_equalityEngine->assertEquality(predicate, true, fact); - } - } else { - // A system-wide invariant; predicates are registered before they are - // asserted - Assert(isRegistered(predicate)); - - if (d_equalityEngine->isFunctionKind(predicate.getKind())) - { - Debug("fp-eq") << "TheoryFp::check(): adding predicate " << predicate - << " is " << !negated << std::endl; - d_equalityEngine->assertPredicate(predicate, !negated, fact); - } - } - } - +void TheoryFp::postCheck(Effort level) +{ // Resolve the abstractions for the conversion lemmas if (level == EFFORT_LAST_CALL) { Trace("fp") << "TheoryFp::check(): checking abstractions" << std::endl; - TheoryModel *m = getValuation().getModel(); + TheoryModel* m = getValuation().getModel(); bool lemmaAdded = false; for (abstractionMapType::const_iterator i = abstractionMap.begin(); @@ -1017,11 +965,35 @@ void TheoryFp::check(Effort level) { } Trace("fp") << "TheoryFp::check(): completed" << std::endl; - /* Checking should be handled by the bit-vector engine */ - return; +} -} /* TheoryFp::check() */ +bool TheoryFp::preNotifyFact( + TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal) +{ + if (atom.getKind() == kind::EQUAL) + { + Assert(!(atom[0].getType().isFloatingPoint() + || atom[0].getType().isRoundingMode()) + || isRegistered(atom[0])); + Assert(!(atom[1].getType().isFloatingPoint() + || atom[1].getType().isRoundingMode()) + || isRegistered(atom[1])); + registerTerm(atom); // Needed for float equalities + } + else + { + // A system-wide invariant; predicates are registered before they are + // asserted + Assert(isRegistered(atom)); + + if (!d_equalityEngine->isFunctionKind(atom.getKind())) + { + return true; + } + } + return false; +} TrustNode TheoryFp::explain(TNode n) { @@ -1047,16 +1019,20 @@ Node TheoryFp::getModelValue(TNode var) { return d_conv.getValue(d_valuation, var); } -bool TheoryFp::collectModelInfo(TheoryModel *m) +bool TheoryFp::collectModelInfo(TheoryModel* m) { std::set relevantTerms; - - Trace("fp-collectModelInfo") - << "TheoryFp::collectModelInfo(): begin" << std::endl; - // Work out which variables are needed computeRelevantTerms(relevantTerms); + // this override behavior to not assert equality engine + return collectModelValues(m, relevantTerms); +} +bool TheoryFp::collectModelValues(TheoryModel* m, + const std::set& relevantTerms) +{ + Trace("fp-collectModelInfo") + << "TheoryFp::collectModelInfo(): begin" << std::endl; if (Trace.isOn("fp-collectModelInfo")) { for (std::set::const_iterator i(relevantTerms.begin()); i != relevantTerms.end(); ++i) { @@ -1153,9 +1129,9 @@ bool TheoryFp::NotifyClass::eqNotifyTriggerPredicate(TNode predicate, << predicate << " is " << value << std::endl; if (value) { - return d_theorySolver.handlePropagation(predicate); + return d_theorySolver.propagateLit(predicate); } - return d_theorySolver.handlePropagation(predicate.notNode()); + return d_theorySolver.propagateLit(predicate.notNode()); } bool TheoryFp::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, @@ -1164,22 +1140,15 @@ bool TheoryFp::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, << t1 << (value ? " = " : " != ") << t2 << std::endl; if (value) { - return d_theorySolver.handlePropagation(t1.eqNode(t2)); - } else { - return d_theorySolver.handlePropagation(t1.eqNode(t2).notNode()); + return d_theorySolver.propagateLit(t1.eqNode(t2)); } + return d_theorySolver.propagateLit(t1.eqNode(t2).notNode()); } void TheoryFp::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) { Debug("fp-eq") << "TheoryFp::eqNotifyConstantTermMerge(): call back as " << t1 << " = " << t2 << std::endl; - - std::vector assumptions; - d_theorySolver.d_equalityEngine->explainEquality(t1, t2, true, assumptions); - - Node conflict = helper::buildConjunct(assumptions); - - d_theorySolver.handleConflict(conflict); + d_theorySolver.conflictEqConstantMerge(t1, t2); } } // namespace fp diff --git a/src/theory/fp/theory_fp.h b/src/theory/fp/theory_fp.h index 79ece7bce..2ef3b3f35 100644 --- a/src/theory/fp/theory_fp.h +++ b/src/theory/fp/theory_fp.h @@ -58,15 +58,28 @@ class TheoryFp : public Theory { TrustNode expandDefinition(Node node) override; void preRegisterTerm(TNode node) override; - void notifySharedTerm(TNode node) override; TrustNode ppRewrite(TNode node) override; - void check(Effort) override; - + //--------------------------------- standard check + /** Do we need a check call at last call effort? */ bool needsCheckLastEffort() override { return true; } + /** Post-check, called after the fact queue of the theory is processed. */ + void postCheck(Effort level) override; + /** Pre-notify fact, return true if processed. */ + bool preNotifyFact(TNode atom, + bool pol, + TNode fact, + bool isPrereg, + bool isInternal) override; + //--------------------------------- end standard check + Node getModelValue(TNode var) override; bool collectModelInfo(TheoryModel* m) override; + /** Collect model values in m based on the relevant terms given by + * relevantTerms */ + bool collectModelValues(TheoryModel* m, + const std::set& relevantTerms) override; std::string identify() const override { return "THEORY_FP"; } @@ -108,10 +121,17 @@ class TheoryFp : public Theory { /** Interaction with the rest of the solver **/ void handleLemma(Node node); - bool handlePropagation(TNode node); - void handleConflict(TNode node); + /** + * Called when literal node is inferred by the equality engine. This + * propagates node on the output channel. + */ + bool propagateLit(TNode node); + /** + * Called when two constants t1 and t2 merge in the equality engine. This + * sends a conflict on the output channel. + */ + void conflictEqConstantMerge(TNode t1, TNode t2); - context::CDO d_conflict; context::CDO d_conflictNode; typedef context::CDHashMap -- 2.30.2