From: Gereon Kremer Date: Fri, 10 Dec 2021 20:31:01 +0000 (-0800) Subject: Eliminate more static rewrites (#7786) X-Git-Tag: cvc5-1.0.0~685 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=b14dddb404897200630c6ee1afeb98a0a24f99e0;p=cvc5.git Eliminate more static rewrites (#7786) This PR eliminates almost all remaining static rewrites from the arithmetic theory. --- diff --git a/src/preprocessing/passes/learned_rewrite.cpp b/src/preprocessing/passes/learned_rewrite.cpp index 642a63aa4..3922525f2 100644 --- a/src/preprocessing/passes/learned_rewrite.cpp +++ b/src/preprocessing/passes/learned_rewrite.cpp @@ -61,7 +61,7 @@ PreprocessingPassResult LearnedRewrite::applyInternal( AssertionPipeline* assertionsToPreprocess) { NodeManager* nm = NodeManager::currentNM(); - arith::BoundInference binfer; + arith::BoundInference binfer(d_env); std::vector learnedLits = d_preprocContext->getLearnedLiterals(); std::unordered_set llrw; std::unordered_map visited; diff --git a/src/theory/arith/bound_inference.cpp b/src/theory/arith/bound_inference.cpp index cd688660a..4423cae61 100644 --- a/src/theory/arith/bound_inference.cpp +++ b/src/theory/arith/bound_inference.cpp @@ -15,6 +15,7 @@ #include "theory/arith/bound_inference.h" +#include "smt/env.h" #include "theory/arith/normal_form.h" #include "theory/rewriter.h" @@ -29,6 +30,8 @@ std::ostream& operator<<(std::ostream& os, const Bounds& b) { << b.upper_value << (b.upper_strict ? ')' : ']'); } +BoundInference::BoundInference(Env& env) : EnvObj(env) {} + void BoundInference::reset() { d_bounds.clear(); } Bounds& BoundInference::get_or_add(const Node& lhs) @@ -53,7 +56,7 @@ Bounds BoundInference::get(const Node& lhs) const const std::map& BoundInference::get() const { return d_bounds; } bool BoundInference::add(const Node& n, bool onlyVariables) { - Node tmp = Rewriter::rewrite(n); + Node tmp = rewrite(n); if (tmp.getKind() == Kind::CONST_BOOLEAN) { return false; @@ -175,19 +178,19 @@ void BoundInference::update_lower_bound(const Node& origin, if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) { b.lower_bound = b.upper_bound = - Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); + rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); } else { - b.lower_bound = Rewriter::rewrite( - nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value)); + b.lower_bound = + rewrite(nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value)); } } else if (strict && b.lower_value == value) { auto* nm = NodeManager::currentNM(); b.lower_strict = strict; - b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value)); + b.lower_bound = rewrite(nm->mkNode(Kind::GT, lhs, value)); b.lower_origin = origin; } } @@ -210,19 +213,19 @@ void BoundInference::update_upper_bound(const Node& origin, if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value) { b.lower_bound = b.upper_bound = - Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); + rewrite(nm->mkNode(Kind::EQUAL, lhs, value)); } else { - b.upper_bound = Rewriter::rewrite( - nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value)); + b.upper_bound = + rewrite(nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value)); } } else if (strict && b.upper_value == value) { auto* nm = NodeManager::currentNM(); b.upper_strict = strict; - b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value)); + b.upper_bound = rewrite(nm->mkNode(Kind::LT, lhs, value)); b.upper_origin = origin; } } @@ -238,20 +241,6 @@ std::ostream& operator<<(std::ostream& os, const BoundInference& bi) return os; } -std::map> getBounds(const std::vector& assertions) { - BoundInference bi; - for (const auto& a: assertions) { - bi.add(a); - } - std::map> res; - for (const auto& b : bi.get()) - { - res.emplace(b.first, - std::make_pair(b.second.lower_value, b.second.upper_value)); - } - return res; -} - } // namespace arith } // namespace theory } // namespace cvc5 diff --git a/src/theory/arith/bound_inference.h b/src/theory/arith/bound_inference.h index e8d7a294f..a3043ee93 100644 --- a/src/theory/arith/bound_inference.h +++ b/src/theory/arith/bound_inference.h @@ -21,6 +21,7 @@ #include #include "expr/node.h" +#include "smt/env_obj.h" namespace cvc5 { namespace theory { @@ -53,9 +54,10 @@ namespace arith { * A utility class that extracts direct bounds on arithmetic terms from theory * atoms. */ - class BoundInference + class BoundInference : protected EnvObj { public: + BoundInference(Env& env); void reset(); /** @@ -110,8 +112,6 @@ namespace arith { /** Print the current variable bounds. */ std::ostream& operator<<(std::ostream& os, const BoundInference& bi); -std::map> getBounds(const std::vector& assertions); - } // namespace arith } // namespace theory } // namespace cvc5 diff --git a/src/theory/arith/constraint.cpp b/src/theory/arith/constraint.cpp index cffacdc39..a9576e0cc 100644 --- a/src/theory/arith/constraint.cpp +++ b/src/theory/arith/constraint.cpp @@ -1551,9 +1551,6 @@ TrustNode Constraint::externalExplainForPropagation(TNode lit) const Node n = safeConstructNary(nb); if (d_database->isProofEnabled()) { - // Check that the literal we're explaining via this constraint actually - // matches the constraint's canonical literal. - Assert(Rewriter::rewrite(lit) == getLiteral()); std::vector assumptions; if (n.getKind() == Kind::AND) { diff --git a/src/theory/arith/infer_bounds.cpp b/src/theory/arith/infer_bounds.cpp index aae9bae62..21f698e45 100644 --- a/src/theory/arith/infer_bounds.cpp +++ b/src/theory/arith/infer_bounds.cpp @@ -163,9 +163,7 @@ Node InferBoundsResult::getLiteral() const{ Assert(getValue().infinitesimalSgn() >= 0); k = boundIsRational() ? kind::GEQ : kind::GT; } - Node atom = nm->mkNode(k, getTerm(), qnode); - Node lit = Rewriter::rewrite(atom); - return lit; + return nm->mkNode(k, getTerm(), qnode); } /* If there is a bound, this is a node that explains the bound. */ diff --git a/src/theory/arith/nl/icp/icp_solver.cpp b/src/theory/arith/nl/icp/icp_solver.cpp index 92c7d3ddd..aab63325e 100644 --- a/src/theory/arith/nl/icp/icp_solver.cpp +++ b/src/theory/arith/nl/icp/icp_solver.cpp @@ -66,7 +66,7 @@ inline std::ostream& operator<<(std::ostream& os, const IAWrapper& iaw) } // namespace ICPSolver::ICPSolver(Env& env, InferenceManager& im) - : EnvObj(env), d_im(im), d_state(d_mapper) + : EnvObj(env), d_im(im), d_state(env, d_mapper) { } diff --git a/src/theory/arith/nl/icp/icp_solver.h b/src/theory/arith/nl/icp/icp_solver.h index 8b0fbf583..b849255cc 100644 --- a/src/theory/arith/nl/icp/icp_solver.h +++ b/src/theory/arith/nl/icp/icp_solver.h @@ -86,12 +86,12 @@ class ICPSolver : protected EnvObj std::vector d_conflict; /** Initialized the variable bounds with a variable mapper */ - ICPState(VariableMapper& vm) {} + ICPState(Env& env, VariableMapper& vm) : d_bounds(env) {} /** Reset this state */ void reset() { - d_bounds = BoundInference(); + d_bounds.reset(); d_candidates.clear(); d_assignment.clear(); d_origins = ContractionOriginManager(); diff --git a/src/theory/arith/nl/nl_model.cpp b/src/theory/arith/nl/nl_model.cpp index d23ddd53d..90138bf3e 100644 --- a/src/theory/arith/nl/nl_model.cpp +++ b/src/theory/arith/nl/nl_model.cpp @@ -32,7 +32,7 @@ namespace theory { namespace arith { namespace nl { -NlModel::NlModel() : d_used_approx(false) +NlModel::NlModel(Env& env) : EnvObj(env), d_used_approx(false) { d_true = NodeManager::currentNM()->mkConst(true); d_false = NodeManager::currentNM()->mkConst(false); @@ -122,7 +122,7 @@ Node NlModel::computeModelValue(TNode n, bool isConcrete) children.emplace_back(computeModelValue(n[i], isConcrete)); } ret = NodeManager::currentNM()->mkNode(n.getKind(), children); - ret = Rewriter::rewrite(ret); + ret = rewrite(ret); } } Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "[" @@ -246,7 +246,7 @@ bool NlModel::checkModel(const std::vector& assertions, // apply the substitution to a if (!d_substitutions.empty()) { - av = Rewriter::rewrite(arithSubstitute(av, d_substitutions)); + av = rewrite(arithSubstitute(av, d_substitutions)); } // simple check literal if (!simpleCheckModelLit(av)) @@ -307,7 +307,7 @@ bool NlModel::addSubstitution(TNode v, TNode s) Node ms = arithSubstitute(sub, tmp); if (ms != sub) { - sub = Rewriter::rewrite(ms); + sub = rewrite(ms); } } d_substitutions.add(v, s); @@ -376,7 +376,7 @@ bool NlModel::solveEqualitySimple(Node eq, if (!d_substitutions.empty()) { seq = arithSubstitute(eq, d_substitutions); - seq = Rewriter::rewrite(seq); + seq = rewrite(seq); if (seq.isConst()) { if (seq.getConst()) @@ -580,7 +580,7 @@ bool NlModel::simpleCheckModelLit(Node lit) { lit2 = lit2.negate(); } - lit2 = Rewriter::rewrite(lit2); + lit2 = rewrite(lit2); bool success = simpleCheckModelLit(lit2); if (success != pol) { @@ -669,7 +669,7 @@ bool NlModel::simpleCheckModelLit(Node lit) b = it->second; t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v)); } - t = Rewriter::rewrite(t); + t = rewrite(t); Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic " << t << "..." << std::endl; Trace("nl-ext-cms-debug") << " a = " << a << std::endl; @@ -677,7 +677,7 @@ bool NlModel::simpleCheckModelLit(Node lit) // find maximal/minimal value on the interval Node apex = nm->mkNode( DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a)); - apex = Rewriter::rewrite(apex); + apex = rewrite(apex); Assert(apex.isConst()); // for lower, upper, whether we are greater than the apex bool cmp[2]; @@ -686,7 +686,7 @@ bool NlModel::simpleCheckModelLit(Node lit) { boundn[r] = r == 0 ? bit->second.first : bit->second.second; Node cmpn = nm->mkNode(GT, boundn[r], apex); - cmpn = Rewriter::rewrite(cmpn); + cmpn = rewrite(cmpn); Assert(cmpn.isConst()); cmp[r] = cmpn.getConst(); } @@ -717,12 +717,12 @@ bool NlModel::simpleCheckModelLit(Node lit) { qsub.d_subs.back() = boundn[r]; Node ts = arithSubstitute(t, qsub); - tcmpn[r] = Rewriter::rewrite(ts); + tcmpn[r] = rewrite(ts); } Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]); Trace("nl-ext-cms-debug") << " ...both sides of apex, compare " << tcmp << std::endl; - tcmp = Rewriter::rewrite(tcmp); + tcmp = rewrite(tcmp); Assert(tcmp.isConst()); unsigned bindex_use = (tcmp.getConst() == pol) ? 1 : 0; Trace("nl-ext-cms-debug") @@ -756,7 +756,7 @@ bool NlModel::simpleCheckModelLit(Node lit) if (!qsub.empty()) { Node slit = arithSubstitute(lit, qsub); - slit = Rewriter::rewrite(slit); + slit = rewrite(slit); return simpleCheckModelLit(slit); } return false; @@ -1003,7 +1003,7 @@ bool NlModel::simpleCheckModelMsum(const std::map& msum, bool pol) comp = comp.negate(); } Trace("nl-ext-cms") << " comparison is : " << comp << std::endl; - comp = Rewriter::rewrite(comp); + comp = rewrite(comp); Assert(comp.isConst()); Trace("nl-ext-cms") << " returned : " << comp << std::endl; return comp == d_true; @@ -1073,7 +1073,7 @@ void NlModel::getModelValueRepair( witness = nm->mkNode(MULT, nm->mkConst(CONST_RATIONAL, Rational(1, 2)), nm->mkNode(PLUS, l, u)); - witness = Rewriter::rewrite(witness); + witness = rewrite(witness); Trace("nl-model") << v << " witness is " << witness << std::endl; } approximations[v] = std::pair(pred, witness); diff --git a/src/theory/arith/nl/nl_model.h b/src/theory/arith/nl/nl_model.h index 7dcd89a4a..e195aa9b2 100644 --- a/src/theory/arith/nl/nl_model.h +++ b/src/theory/arith/nl/nl_model.h @@ -23,6 +23,7 @@ #include "expr/kind.h" #include "expr/node.h" #include "expr/subs.h" +#include "smt/env_obj.h" namespace cvc5 { @@ -48,12 +49,12 @@ class NonlinearExtension; * model in the case it can determine that a model exists. These include * techniques based on solving (quadratic) equations and bound analysis. */ -class NlModel +class NlModel : protected EnvObj { friend class NonlinearExtension; public: - NlModel(); + NlModel(Env& env); ~NlModel(); /** * This method is called once at the beginning of a last call effort check, diff --git a/src/theory/arith/nl/nonlinear_extension.cpp b/src/theory/arith/nl/nonlinear_extension.cpp index e75741096..77bb164a9 100644 --- a/src/theory/arith/nl/nonlinear_extension.cpp +++ b/src/theory/arith/nl/nonlinear_extension.cpp @@ -48,7 +48,7 @@ NonlinearExtension::NonlinearExtension(Env& env, d_checkCounter(0), d_extTheoryCb(state.getEqualityEngine()), d_extTheory(env, d_extTheoryCb, d_im), - d_model(), + d_model(env), d_trSlv(d_env, d_im, d_model), d_extState(d_im, d_model, d_env), d_factoringSlv(d_env, &d_extState), @@ -122,7 +122,7 @@ void NonlinearExtension::getAssertions(std::vector& assertions) } Valuation v = d_containing.getValuation(); - BoundInference bounds; + BoundInference bounds(d_env); std::unordered_set init_assertions; diff --git a/src/theory/arith/theory_arith_private.cpp b/src/theory/arith/theory_arith_private.cpp index 643ba9a28..bf8798485 100644 --- a/src/theory/arith/theory_arith_private.cpp +++ b/src/theory/arith/theory_arith_private.cpp @@ -4782,8 +4782,11 @@ std::pair TheoryArithPrivate::entailmentCheck(TNode lit, const Arith return make_pair(false, Node::null()); } -bool TheoryArithPrivate::decomposeTerm(Node term, Rational& m, Node& p, Rational& c){ - Node t = Rewriter::rewrite(term); +bool TheoryArithPrivate::decomposeTerm(Node t, + Rational& m, + Node& p, + Rational& c) +{ if(!Polynomial::isMember(t)){ return false; } @@ -4879,12 +4882,13 @@ bool TheoryArithPrivate::decomposeLiteral(Node lit, Kind& k, int& dir, Rational& // left : lm*( lp ) + lc // right: rm*( rp ) + rc Rational lc, rc; - bool success = decomposeTerm(left, lm, lp, lc); + bool success = decomposeTerm(rewrite(left), lm, lp, lc); if(!success){ return false; } - success = decomposeTerm(right, rm, rp, rc); + success = decomposeTerm(rewrite(right), rm, rp, rc); if(!success){ return false; } - Node diff = Rewriter::rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right)); + Node diff = + rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right)); Rational dc; success = decomposeTerm(diff, dm, dp, dc); Assert(success); diff --git a/src/theory/arith/theory_arith_private.h b/src/theory/arith/theory_arith_private.h index 918f73a53..7c90352d2 100644 --- a/src/theory/arith/theory_arith_private.h +++ b/src/theory/arith/theory_arith_private.h @@ -162,20 +162,26 @@ private: //std::pair inferBound(TNode term, bool lb, int maxRounds = -1, const DeltaRational* threshold = NULL); private: - static bool decomposeTerm(Node term, Rational& m, Node& p, Rational& c); - static bool decomposeLiteral(Node lit, Kind& k, int& dir, Rational& lm, Node& lp, Rational& rm, Node& rp, Rational& dm, Node& dp, DeltaRational& sep); - static void setToMin(int sgn, std::pair& min, const std::pair& e); - - /** - * The map between arith variables to nodes. - */ - //ArithVarNodeMap d_arithvarNodeMap; - - typedef ArithVariables::var_iterator var_iterator; - var_iterator var_begin() const { return d_partialModel.var_begin(); } - var_iterator var_end() const { return d_partialModel.var_end(); } - - NodeSet d_setupNodes; + static bool decomposeTerm(Node t, Rational& m, Node& p, Rational& c); + bool decomposeLiteral(Node lit, + Kind& k, + int& dir, + Rational& lm, + Node& lp, + Rational& rm, + Node& rp, + Rational& dm, + Node& dp, + DeltaRational& sep); + static void setToMin(int sgn, + std::pair& min, + const std::pair& e); + + typedef ArithVariables::var_iterator var_iterator; + var_iterator var_begin() const { return d_partialModel.var_begin(); } + var_iterator var_end() const { return d_partialModel.var_end(); } + + NodeSet d_setupNodes; public: bool isSetup(Node n) const { return d_setupNodes.find(n) != d_setupNodes.end();