This PR eliminates almost all remaining static rewrites from the arithmetic theory.
AssertionPipeline* assertionsToPreprocess)
{
NodeManager* nm = NodeManager::currentNM();
- arith::BoundInference binfer;
+ arith::BoundInference binfer(d_env);
std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
std::unordered_set<Node> llrw;
std::unordered_map<TNode, Node> visited;
#include "theory/arith/bound_inference.h"
+#include "smt/env.h"
#include "theory/arith/normal_form.h"
#include "theory/rewriter.h"
<< b.upper_value << (b.upper_strict ? ')' : ']');
}
+BoundInference::BoundInference(Env& env) : EnvObj(env) {}
+
void BoundInference::reset() { d_bounds.clear(); }
Bounds& BoundInference::get_or_add(const Node& lhs)
const std::map<Node, Bounds>& BoundInference::get() const { return d_bounds; }
bool BoundInference::add(const Node& n, bool onlyVariables)
{
- Node tmp = Rewriter::rewrite(n);
+ Node tmp = rewrite(n);
if (tmp.getKind() == Kind::CONST_BOOLEAN)
{
return false;
if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
{
b.lower_bound = b.upper_bound =
- Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+ rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
}
else
{
- b.lower_bound = Rewriter::rewrite(
- nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
+ b.lower_bound =
+ rewrite(nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
}
}
else if (strict && b.lower_value == value)
{
auto* nm = NodeManager::currentNM();
b.lower_strict = strict;
- b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
+ b.lower_bound = rewrite(nm->mkNode(Kind::GT, lhs, value));
b.lower_origin = origin;
}
}
if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
{
b.lower_bound = b.upper_bound =
- Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+ rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
}
else
{
- b.upper_bound = Rewriter::rewrite(
- nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
+ b.upper_bound =
+ rewrite(nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
}
}
else if (strict && b.upper_value == value)
{
auto* nm = NodeManager::currentNM();
b.upper_strict = strict;
- b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
+ b.upper_bound = rewrite(nm->mkNode(Kind::LT, lhs, value));
b.upper_origin = origin;
}
}
return os;
}
-std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) {
- BoundInference bi;
- for (const auto& a: assertions) {
- bi.add(a);
- }
- std::map<Node, std::pair<Node,Node>> res;
- for (const auto& b : bi.get())
- {
- res.emplace(b.first,
- std::make_pair(b.second.lower_value, b.second.upper_value));
- }
- return res;
-}
-
} // namespace arith
} // namespace theory
} // namespace cvc5
#include <vector>
#include "expr/node.h"
+#include "smt/env_obj.h"
namespace cvc5 {
namespace theory {
* A utility class that extracts direct bounds on arithmetic terms from theory
* atoms.
*/
- class BoundInference
+ class BoundInference : protected EnvObj
{
public:
+ BoundInference(Env& env);
void reset();
/**
/** Print the current variable bounds. */
std::ostream& operator<<(std::ostream& os, const BoundInference& bi);
-std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions);
-
} // namespace arith
} // namespace theory
} // namespace cvc5
Node n = safeConstructNary(nb);
if (d_database->isProofEnabled())
{
- // Check that the literal we're explaining via this constraint actually
- // matches the constraint's canonical literal.
- Assert(Rewriter::rewrite(lit) == getLiteral());
std::vector<Node> assumptions;
if (n.getKind() == Kind::AND)
{
Assert(getValue().infinitesimalSgn() >= 0);
k = boundIsRational() ? kind::GEQ : kind::GT;
}
- Node atom = nm->mkNode(k, getTerm(), qnode);
- Node lit = Rewriter::rewrite(atom);
- return lit;
+ return nm->mkNode(k, getTerm(), qnode);
}
/* If there is a bound, this is a node that explains the bound. */
} // namespace
ICPSolver::ICPSolver(Env& env, InferenceManager& im)
- : EnvObj(env), d_im(im), d_state(d_mapper)
+ : EnvObj(env), d_im(im), d_state(env, d_mapper)
{
}
std::vector<Node> d_conflict;
/** Initialized the variable bounds with a variable mapper */
- ICPState(VariableMapper& vm) {}
+ ICPState(Env& env, VariableMapper& vm) : d_bounds(env) {}
/** Reset this state */
void reset()
{
- d_bounds = BoundInference();
+ d_bounds.reset();
d_candidates.clear();
d_assignment.clear();
d_origins = ContractionOriginManager();
namespace arith {
namespace nl {
-NlModel::NlModel() : d_used_approx(false)
+NlModel::NlModel(Env& env) : EnvObj(env), d_used_approx(false)
{
d_true = NodeManager::currentNM()->mkConst(true);
d_false = NodeManager::currentNM()->mkConst(false);
children.emplace_back(computeModelValue(n[i], isConcrete));
}
ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
- ret = Rewriter::rewrite(ret);
+ ret = rewrite(ret);
}
}
Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
// apply the substitution to a
if (!d_substitutions.empty())
{
- av = Rewriter::rewrite(arithSubstitute(av, d_substitutions));
+ av = rewrite(arithSubstitute(av, d_substitutions));
}
// simple check literal
if (!simpleCheckModelLit(av))
Node ms = arithSubstitute(sub, tmp);
if (ms != sub)
{
- sub = Rewriter::rewrite(ms);
+ sub = rewrite(ms);
}
}
d_substitutions.add(v, s);
if (!d_substitutions.empty())
{
seq = arithSubstitute(eq, d_substitutions);
- seq = Rewriter::rewrite(seq);
+ seq = rewrite(seq);
if (seq.isConst())
{
if (seq.getConst<bool>())
{
lit2 = lit2.negate();
}
- lit2 = Rewriter::rewrite(lit2);
+ lit2 = rewrite(lit2);
bool success = simpleCheckModelLit(lit2);
if (success != pol)
{
b = it->second;
t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
}
- t = Rewriter::rewrite(t);
+ t = rewrite(t);
Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
<< t << "..." << std::endl;
Trace("nl-ext-cms-debug") << " a = " << a << std::endl;
// find maximal/minimal value on the interval
Node apex = nm->mkNode(
DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
- apex = Rewriter::rewrite(apex);
+ apex = rewrite(apex);
Assert(apex.isConst());
// for lower, upper, whether we are greater than the apex
bool cmp[2];
{
boundn[r] = r == 0 ? bit->second.first : bit->second.second;
Node cmpn = nm->mkNode(GT, boundn[r], apex);
- cmpn = Rewriter::rewrite(cmpn);
+ cmpn = rewrite(cmpn);
Assert(cmpn.isConst());
cmp[r] = cmpn.getConst<bool>();
}
{
qsub.d_subs.back() = boundn[r];
Node ts = arithSubstitute(t, qsub);
- tcmpn[r] = Rewriter::rewrite(ts);
+ tcmpn[r] = rewrite(ts);
}
Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
Trace("nl-ext-cms-debug")
<< " ...both sides of apex, compare " << tcmp << std::endl;
- tcmp = Rewriter::rewrite(tcmp);
+ tcmp = rewrite(tcmp);
Assert(tcmp.isConst());
unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
Trace("nl-ext-cms-debug")
if (!qsub.empty())
{
Node slit = arithSubstitute(lit, qsub);
- slit = Rewriter::rewrite(slit);
+ slit = rewrite(slit);
return simpleCheckModelLit(slit);
}
return false;
comp = comp.negate();
}
Trace("nl-ext-cms") << " comparison is : " << comp << std::endl;
- comp = Rewriter::rewrite(comp);
+ comp = rewrite(comp);
Assert(comp.isConst());
Trace("nl-ext-cms") << " returned : " << comp << std::endl;
return comp == d_true;
witness = nm->mkNode(MULT,
nm->mkConst(CONST_RATIONAL, Rational(1, 2)),
nm->mkNode(PLUS, l, u));
- witness = Rewriter::rewrite(witness);
+ witness = rewrite(witness);
Trace("nl-model") << v << " witness is " << witness << std::endl;
}
approximations[v] = std::pair<Node, Node>(pred, witness);
#include "expr/kind.h"
#include "expr/node.h"
#include "expr/subs.h"
+#include "smt/env_obj.h"
namespace cvc5 {
* model in the case it can determine that a model exists. These include
* techniques based on solving (quadratic) equations and bound analysis.
*/
-class NlModel
+class NlModel : protected EnvObj
{
friend class NonlinearExtension;
public:
- NlModel();
+ NlModel(Env& env);
~NlModel();
/**
* This method is called once at the beginning of a last call effort check,
d_checkCounter(0),
d_extTheoryCb(state.getEqualityEngine()),
d_extTheory(env, d_extTheoryCb, d_im),
- d_model(),
+ d_model(env),
d_trSlv(d_env, d_im, d_model),
d_extState(d_im, d_model, d_env),
d_factoringSlv(d_env, &d_extState),
}
Valuation v = d_containing.getValuation();
- BoundInference bounds;
+ BoundInference bounds(d_env);
std::unordered_set<Node> init_assertions;
return make_pair(false, Node::null());
}
-bool TheoryArithPrivate::decomposeTerm(Node term, Rational& m, Node& p, Rational& c){
- Node t = Rewriter::rewrite(term);
+bool TheoryArithPrivate::decomposeTerm(Node t,
+ Rational& m,
+ Node& p,
+ Rational& c)
+{
if(!Polynomial::isMember(t)){
return false;
}
// left : lm*( lp ) + lc
// right: rm*( rp ) + rc
Rational lc, rc;
- bool success = decomposeTerm(left, lm, lp, lc);
+ bool success = decomposeTerm(rewrite(left), lm, lp, lc);
if(!success){ return false; }
- success = decomposeTerm(right, rm, rp, rc);
+ success = decomposeTerm(rewrite(right), rm, rp, rc);
if(!success){ return false; }
- Node diff = Rewriter::rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right));
+ Node diff =
+ rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right));
Rational dc;
success = decomposeTerm(diff, dm, dp, dc);
Assert(success);
//std::pair<DeltaRational, Node> inferBound(TNode term, bool lb, int maxRounds = -1, const DeltaRational* threshold = NULL);
private:
- static bool decomposeTerm(Node term, Rational& m, Node& p, Rational& c);
- static bool decomposeLiteral(Node lit, Kind& k, int& dir, Rational& lm, Node& lp, Rational& rm, Node& rp, Rational& dm, Node& dp, DeltaRational& sep);
- static void setToMin(int sgn, std::pair<Node, DeltaRational>& min, const std::pair<Node, DeltaRational>& e);
-
- /**
- * The map between arith variables to nodes.
- */
- //ArithVarNodeMap d_arithvarNodeMap;
-
- typedef ArithVariables::var_iterator var_iterator;
- var_iterator var_begin() const { return d_partialModel.var_begin(); }
- var_iterator var_end() const { return d_partialModel.var_end(); }
-
- NodeSet d_setupNodes;
+ static bool decomposeTerm(Node t, Rational& m, Node& p, Rational& c);
+ bool decomposeLiteral(Node lit,
+ Kind& k,
+ int& dir,
+ Rational& lm,
+ Node& lp,
+ Rational& rm,
+ Node& rp,
+ Rational& dm,
+ Node& dp,
+ DeltaRational& sep);
+ static void setToMin(int sgn,
+ std::pair<Node, DeltaRational>& min,
+ const std::pair<Node, DeltaRational>& e);
+
+ typedef ArithVariables::var_iterator var_iterator;
+ var_iterator var_begin() const { return d_partialModel.var_begin(); }
+ var_iterator var_end() const { return d_partialModel.var_end(); }
+
+ NodeSet d_setupNodes;
public:
bool isSetup(Node n) const {
return d_setupNodes.find(n) != d_setupNodes.end();