From e2a64ae3e03ade771363df90dfa3f50b87a9205a Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Wed, 16 Sep 2020 10:21:40 -0500 Subject: [PATCH] Refactor collectModelInfo in TheoryArith (#5027) This is work towards updating the arithmetic solver to the new standard, and in particular isolating TheoryArithPrivate as the "linear solver", and TheoryArith as the overall approach for arithmetic. This transfers ownership of the non-linear extension from TheoryArithPrivate to TheoryArith. The former still has a pointer to the non-linear extension, which will be removed with further refactoring. This PR additionally moves the code that handles the interplay of the non-linear extension in TheoryArithPrivate::collectModelInfo to TheoryArith, and simplifies the model interface for TheoryArithPrivate. --- src/theory/arith/theory_arith.cpp | 59 ++++++++++++++++++++++- src/theory/arith/theory_arith.h | 19 ++++++-- src/theory/arith/theory_arith_private.cpp | 57 ++-------------------- src/theory/arith/theory_arith_private.h | 14 +++++- 4 files changed, 91 insertions(+), 58 deletions(-) diff --git a/src/theory/arith/theory_arith.cpp b/src/theory/arith/theory_arith.cpp index 1436198a8..4884d8484 100644 --- a/src/theory/arith/theory_arith.cpp +++ b/src/theory/arith/theory_arith.cpp @@ -21,6 +21,7 @@ #include "smt/smt_statistics_registry.h" #include "theory/arith/arith_rewriter.h" #include "theory/arith/infer_bounds.h" +#include "theory/arith/nl/nonlinear_extension.h" #include "theory/arith/theory_arith_private.h" #include "theory/ext_theory.h" @@ -42,7 +43,8 @@ TheoryArith::TheoryArith(context::Context* c, new TheoryArithPrivate(*this, c, u, out, valuation, logicInfo, pnm)), d_ppRewriteTimer("theory::arith::ppRewriteTimer"), d_astate(*d_internal, c, u, valuation), - d_inferenceManager(*this, d_astate, pnm) + d_inferenceManager(*this, d_astate, pnm), + d_nonlinearExtension(nullptr) { smtStatisticsRegistry()->registerStat(&d_ppRewriteTimer); @@ -76,6 +78,13 @@ void TheoryArith::finishInit() d_valuation.setUnevaluatedKind(kind::SINE); d_valuation.setUnevaluatedKind(kind::PI); } + // only need to create nonlinear extension if non-linear logic + const LogicInfo& logicInfo = getLogicInfo(); + if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) + { + d_nonlinearExtension.reset( + new nl::NonlinearExtension(*this, d_equalityEngine)); + } // finish initialize internally d_internal->finishInit(); } @@ -123,7 +132,53 @@ void TheoryArith::propagate(Effort e) { } bool TheoryArith::collectModelInfo(TheoryModel* m) { - return d_internal->collectModelInfo(m); + std::set termSet; + // Work out which variables are needed + const std::set& irrKinds = m->getIrrelevantKinds(); + computeAssertedTerms(termSet, irrKinds); + // this overrides behavior to not assert equality engine + return collectModelValues(m, termSet); +} + +bool TheoryArith::collectModelValues(TheoryModel* m, + const std::set& termSet) +{ + // get the model from the linear solver + std::map arithModel; + d_internal->collectModelValues(termSet, arithModel); + // if non-linear is enabled, intercept the model, which may repair its values + if (d_nonlinearExtension != nullptr) + { + // Non-linear may repair values to satisfy non-linear constraints (see + // documentation for NonlinearExtension::interceptModel). + d_nonlinearExtension->interceptModel(arithModel); + } + // We are now ready to assert the model. + for (const std::pair& p : arithModel) + { + // maps to constant of comparable type + Assert(p.first.getType().isComparableTo(p.second.getType())); + Assert(p.second.isConst()); + if (m->assertEquality(p.first, p.second, true)) + { + continue; + } + // If we failed to assert an equality, it is likely due to theory + // combination, namely the repaired model for non-linear changed + // an equality status that was agreed upon by both (linear) arithmetic + // and another theory. In this case, we must add a lemma, or otherwise + // we would terminate with an invalid model. Thus, we add a splitting + // lemma of the form ( x = v V x != v ) where v is the model value + // assigned by the non-linear solver to x. + if (d_nonlinearExtension != nullptr) + { + Node eq = p.first.eqNode(p.second); + Node lem = NodeManager::currentNM()->mkNode(kind::OR, eq, eq.negate()); + d_out->lemma(lem); + } + return false; + } + return true; } void TheoryArith::notifyRestart(){ diff --git a/src/theory/arith/theory_arith.h b/src/theory/arith/theory_arith.h index 4851f1c5d..30ad724cc 100644 --- a/src/theory/arith/theory_arith.h +++ b/src/theory/arith/theory_arith.h @@ -25,12 +25,15 @@ namespace CVC4 { namespace theory { - namespace arith { +namespace nl { +class NonlinearExtension; +} + /** - * Implementation of QF_LRA. - * Based upon: + * Implementation of linear and non-linear integer and real arithmetic. + * The linear arithmetic solver is based upon: * http://research.microsoft.com/en-us/um/people/leonardo/cav06.pdf */ class TheoryArith : public Theory { @@ -78,6 +81,11 @@ class TheoryArith : public Theory { TrustNode explain(TNode n) override; bool collectModelInfo(TheoryModel* m) override; + /** + * Collect model values in m based on the relevant terms given by termSet. + */ + bool collectModelValues(TheoryModel* m, + const std::set& termSet) override; void shutdown() override {} @@ -110,6 +118,11 @@ class TheoryArith : public Theory { /** The arith::InferenceManager. */ InferenceManager d_inferenceManager; + /** + * The non-linear extension, responsible for all approaches for non-linear + * arithmetic. + */ + std::unique_ptr d_nonlinearExtension; };/* class TheoryArith */ }/* CVC4::theory::arith namespace */ diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index 1b49b7350..8595e26b5 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -164,7 +164,6 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing, TheoryArithPrivate::~TheoryArithPrivate(){ if(d_treeLog != NULL){ delete d_treeLog; } if(d_approxStats != NULL) { delete d_approxStats; } - if(d_nonlinearExtension != NULL) { delete d_nonlinearExtension; } } TheoryRewriter* TheoryArithPrivate::getTheoryRewriter() { return &d_rewriter; } @@ -177,12 +176,7 @@ void TheoryArithPrivate::finishInit() eq::EqualityEngine* ee = d_containing.getEqualityEngine(); Assert(ee != nullptr); d_congruenceManager.finishInit(ee); - const LogicInfo& logicInfo = getLogicInfo(); - // only need to create nonlinear extension if non-linear logic - if (logicInfo.isTheoryEnabled(THEORY_ARITH) && !logicInfo.isLinear()) - { - d_nonlinearExtension = new nl::NonlinearExtension(d_containing, ee); - } + d_nonlinearExtension = d_containing.d_nonlinearExtension.get(); } static bool contains(const ConstraintCPVec& v, ConstraintP con){ @@ -4074,7 +4068,8 @@ Rational TheoryArithPrivate::deltaValueForTotalOrder() const{ return belowMin; } -bool TheoryArithPrivate::collectModelInfo(TheoryModel* m) +void TheoryArithPrivate::collectModelValues(const std::set& termSet, + std::map& arithModel) { AlwaysAssert(d_qflraStatus == Result::SAT); //AlwaysAssert(!d_nlIncomplete, "Arithmetic solver cannot currently produce models for input with nonlinear arithmetic constraints"); @@ -4085,10 +4080,6 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m) Debug("arith::collectModelInfo") << "collectModelInfo() begin " << endl; - std::set termSet; - const std::set& irrKinds = m->getIrrelevantKinds(); - d_containing.computeAssertedTerms(termSet, irrKinds, true); - // Delta lasts at least the duration of the function call const Rational& delta = d_partialModel.getDelta(); std::unordered_set shared = d_containing.currentlySharedTerms(); @@ -4096,8 +4087,6 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m) // TODO: // This is not very good for user push/pop.... // Revisit when implementing push/pop - // Map of terms to values, constructed when non-linear arithmetic is active. - std::map arithModel; for(var_iterator vi = var_begin(), vend = var_end(); vi != vend; ++vi){ ArithVar v = *vi; @@ -4112,56 +4101,20 @@ bool TheoryArithPrivate::collectModelInfo(TheoryModel* m) Node qNode = mkRationalNode(qmodel); Debug("arith::collectModelInfo") << "m->assertEquality(" << term << ", " << qmodel << ", true)" << endl; - if (d_nonlinearExtension != nullptr) - { - // Let non-linear extension inspect the values before they are sent - // to the theory model. - arithModel[term] = qNode; - } - else - { - if (!m->assertEquality(term, qNode, true)) - { - return false; - } - } + // Add to the map + arithModel[term] = qNode; }else{ Debug("arith::collectModelInfo") << "Skipping m->assertEquality(" << term << ", true)" << endl; } } } - if (d_nonlinearExtension != nullptr) - { - // Non-linear may repair values to satisfy non-linear constraints (see - // documentation for NonlinearExtension::interceptModel). - d_nonlinearExtension->interceptModel(arithModel); - // We are now ready to assert the model. - for (std::pair& p : arithModel) - { - if (!m->assertEquality(p.first, p.second, true)) - { - // If we failed to assert an equality, it is likely due to theory - // combination, namely the repaired model for non-linear changed - // an equality status that was agreed upon by both (linear) arithmetic - // and another theory. In this case, we must add a lemma, or otherwise - // we would terminate with an invalid model. Thus, we add a splitting - // lemma of the form ( x = v V x != v ) where v is the model value - // assigned by the non-linear solver to x. - Node eq = p.first.eqNode(p.second); - Node lem = NodeManager::currentNM()->mkNode(kind::OR, eq, eq.negate()); - d_containing.d_out->lemma(lem); - return false; - } - } - } // Iterate over equivalence classes in LinearEqualityModule // const eq::EqualityEngine& ee = d_congruenceManager.getEqualityEngine(); // m->assertEqualityEngine(&ee); Debug("arith::collectModelInfo") << "collectModelInfo() end " << endl; - return true; } bool TheoryArithPrivate::safeToReset() const { diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index d0428f2ef..6d030dece 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -371,7 +371,7 @@ public: FCSimplexDecisionProcedure d_fcSimplex; SumOfInfeasibilitiesSPD d_soiSimplex; AttemptSolutionSDP d_attemptSolSimplex; - + /** non-linear algebraic approach */ nl::NonlinearExtension* d_nonlinearExtension; @@ -456,6 +456,18 @@ public: Rational deltaValueForTotalOrder() const; bool collectModelInfo(TheoryModel* m); + /** + * Collect model values. This is the main method for extracting information + * about how to construct the model. This method relies on the caller for + * processing the map, which is done so that other modules (e.g. the + * non-linear extension) can modify arithModel before it is sent to the model. + * + * @param termSet The set of relevant terms + * @param arithModel Mapping from terms (of real type) to their values. The + * caller should assert equalities to the model for each entry in this map. + */ + void collectModelValues(const std::set& termSet, + std::map& arithModel); void shutdown(){ } -- 2.30.2