Eliminate more static rewrites (#7786)
authorGereon Kremer <gkremer@stanford.edu>
Fri, 10 Dec 2021 20:31:01 +0000 (12:31 -0800)
committerGitHub <noreply@github.com>
Fri, 10 Dec 2021 20:31:01 +0000 (20:31 +0000)
This PR eliminates almost all remaining static rewrites from the arithmetic theory.

12 files changed:
src/preprocessing/passes/learned_rewrite.cpp
src/theory/arith/bound_inference.cpp
src/theory/arith/bound_inference.h
src/theory/arith/constraint.cpp
src/theory/arith/infer_bounds.cpp
src/theory/arith/nl/icp/icp_solver.cpp
src/theory/arith/nl/icp/icp_solver.h
src/theory/arith/nl/nl_model.cpp
src/theory/arith/nl/nl_model.h
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h

index 642a63aa48fe916fa378de3456eae52fe9809eb6..3922525f26a24d1db0bb741a4abeb8d8477b87a7 100644 (file)
@@ -61,7 +61,7 @@ PreprocessingPassResult LearnedRewrite::applyInternal(
     AssertionPipeline* assertionsToPreprocess)
 {
   NodeManager* nm = NodeManager::currentNM();
-  arith::BoundInference binfer;
+  arith::BoundInference binfer(d_env);
   std::vector<Node> learnedLits = d_preprocContext->getLearnedLiterals();
   std::unordered_set<Node> llrw;
   std::unordered_map<TNode, Node> visited;
index cd688660acee941eac973accdb647d53cf8c83c9..4423cae616efc6a6fd5e148dfcab61091be45e98 100644 (file)
@@ -15,6 +15,7 @@
 
 #include "theory/arith/bound_inference.h"
 
+#include "smt/env.h"
 #include "theory/arith/normal_form.h"
 #include "theory/rewriter.h"
 
@@ -29,6 +30,8 @@ std::ostream& operator<<(std::ostream& os, const Bounds& b) {
             << b.upper_value << (b.upper_strict ? ')' : ']');
 }
 
+BoundInference::BoundInference(Env& env) : EnvObj(env) {}
+
 void BoundInference::reset() { d_bounds.clear(); }
 
 Bounds& BoundInference::get_or_add(const Node& lhs)
@@ -53,7 +56,7 @@ Bounds BoundInference::get(const Node& lhs) const
 const std::map<Node, Bounds>& BoundInference::get() const { return d_bounds; }
 bool BoundInference::add(const Node& n, bool onlyVariables)
 {
-  Node tmp = Rewriter::rewrite(n);
+  Node tmp = rewrite(n);
   if (tmp.getKind() == Kind::CONST_BOOLEAN)
   {
     return false;
@@ -175,19 +178,19 @@ void BoundInference::update_lower_bound(const Node& origin,
     if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
     {
       b.lower_bound = b.upper_bound =
-          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+          rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
     }
     else
     {
-      b.lower_bound = Rewriter::rewrite(
-          nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
+      b.lower_bound =
+          rewrite(nm->mkNode(strict ? Kind::GT : Kind::GEQ, lhs, value));
     }
   }
   else if (strict && b.lower_value == value)
   {
     auto* nm = NodeManager::currentNM();
     b.lower_strict = strict;
-    b.lower_bound = Rewriter::rewrite(nm->mkNode(Kind::GT, lhs, value));
+    b.lower_bound = rewrite(nm->mkNode(Kind::GT, lhs, value));
     b.lower_origin = origin;
   }
 }
@@ -210,19 +213,19 @@ void BoundInference::update_upper_bound(const Node& origin,
     if (!b.lower_strict && !b.upper_strict && b.lower_value == b.upper_value)
     {
       b.lower_bound = b.upper_bound =
-          Rewriter::rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
+          rewrite(nm->mkNode(Kind::EQUAL, lhs, value));
     }
     else
     {
-      b.upper_bound = Rewriter::rewrite(
-          nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
+      b.upper_bound =
+          rewrite(nm->mkNode(strict ? Kind::LT : Kind::LEQ, lhs, value));
     }
   }
   else if (strict && b.upper_value == value)
   {
     auto* nm = NodeManager::currentNM();
     b.upper_strict = strict;
-    b.upper_bound = Rewriter::rewrite(nm->mkNode(Kind::LT, lhs, value));
+    b.upper_bound = rewrite(nm->mkNode(Kind::LT, lhs, value));
     b.upper_origin = origin;
   }
 }
@@ -238,20 +241,6 @@ std::ostream& operator<<(std::ostream& os, const BoundInference& bi)
   return os;
 }
 
-std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions) {
-  BoundInference bi;
-  for (const auto& a: assertions) {
-    bi.add(a);
-  }
-  std::map<Node, std::pair<Node,Node>> res;
-  for (const auto& b : bi.get())
-  {
-    res.emplace(b.first,
-                std::make_pair(b.second.lower_value, b.second.upper_value));
-  }
-  return res;
-}
-
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5
index e8d7a294f3b004d77bfec8e98b19b4548ae02167..a3043ee9348a4aa4eb97ed9f5c9d9d6a9e18ff3c 100644 (file)
@@ -21,6 +21,7 @@
 #include <vector>
 
 #include "expr/node.h"
+#include "smt/env_obj.h"
 
 namespace cvc5 {
 namespace theory {
@@ -53,9 +54,10 @@ namespace arith {
    * A utility class that extracts direct bounds on arithmetic terms from theory
    * atoms.
    */
-  class BoundInference
+  class BoundInference : protected EnvObj
   {
    public:
+    BoundInference(Env& env);
     void reset();
 
     /**
@@ -110,8 +112,6 @@ namespace arith {
 /** Print the current variable bounds. */
 std::ostream& operator<<(std::ostream& os, const BoundInference& bi);
 
-std::map<Node, std::pair<Node,Node>> getBounds(const std::vector<Node>& assertions);
-
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5
index cffacdc39b4c23a3d41bfbf4c765bf68564a2178..a9576e0cc2f36bb52273c37d4c437730e5a48fcd 100644 (file)
@@ -1551,9 +1551,6 @@ TrustNode Constraint::externalExplainForPropagation(TNode lit) const
   Node n = safeConstructNary(nb);
   if (d_database->isProofEnabled())
   {
-    // Check that the literal we're explaining via this constraint actually
-    // matches the constraint's canonical literal.
-    Assert(Rewriter::rewrite(lit) == getLiteral());
     std::vector<Node> assumptions;
     if (n.getKind() == Kind::AND)
     {
index aae9bae620c8781024bf95ccf8679a24471f8d29..21f698e45c88d65e000528defc981681459ad2bc 100644 (file)
@@ -163,9 +163,7 @@ Node InferBoundsResult::getLiteral() const{
     Assert(getValue().infinitesimalSgn() >= 0);
     k = boundIsRational() ? kind::GEQ : kind::GT;
   }
-  Node atom = nm->mkNode(k, getTerm(), qnode);
-  Node lit = Rewriter::rewrite(atom);
-  return lit;
+  return nm->mkNode(k, getTerm(), qnode);
 }
 
 /* If there is a bound, this is a node that explains the bound. */
index 92c7d3ddd0e72c8725fd4daddf7acd41e8994869..aab63325e3359a291c299bed86a683e4819b1a51 100644 (file)
@@ -66,7 +66,7 @@ inline std::ostream& operator<<(std::ostream& os, const IAWrapper& iaw)
 }  // namespace
 
 ICPSolver::ICPSolver(Env& env, InferenceManager& im)
-    : EnvObj(env), d_im(im), d_state(d_mapper)
+    : EnvObj(env), d_im(im), d_state(env, d_mapper)
 {
 }
 
index 8b0fbf583724dc1fe31316fb90a3bdb6a627f2ae..b849255cc12ed4dfc9a7499b7e6df85a9c423f18 100644 (file)
@@ -86,12 +86,12 @@ class ICPSolver : protected EnvObj
     std::vector<Node> d_conflict;
 
     /** Initialized the variable bounds with a variable mapper */
-    ICPState(VariableMapper& vm) {}
+    ICPState(Env& env, VariableMapper& vm) : d_bounds(env) {}
 
     /** Reset this state */
     void reset()
     {
-      d_bounds = BoundInference();
+      d_bounds.reset();
       d_candidates.clear();
       d_assignment.clear();
       d_origins = ContractionOriginManager();
index d23ddd53dc3d67435ca02ff557a6eb8ae3097993..90138bf3eb6ad1ba22a838811be8f3f782f84bb2 100644 (file)
@@ -32,7 +32,7 @@ namespace theory {
 namespace arith {
 namespace nl {
 
-NlModel::NlModel() : d_used_approx(false)
+NlModel::NlModel(Env& env) : EnvObj(env), d_used_approx(false)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
   d_false = NodeManager::currentNM()->mkConst(false);
@@ -122,7 +122,7 @@ Node NlModel::computeModelValue(TNode n, bool isConcrete)
         children.emplace_back(computeModelValue(n[i], isConcrete));
       }
       ret = NodeManager::currentNM()->mkNode(n.getKind(), children);
-      ret = Rewriter::rewrite(ret);
+      ret = rewrite(ret);
     }
   }
   Trace("nl-ext-mv-debug") << "computed " << (isConcrete ? "M" : "M_A") << "["
@@ -246,7 +246,7 @@ bool NlModel::checkModel(const std::vector<Node>& assertions,
       // apply the substitution to a
       if (!d_substitutions.empty())
       {
-        av = Rewriter::rewrite(arithSubstitute(av, d_substitutions));
+        av = rewrite(arithSubstitute(av, d_substitutions));
       }
       // simple check literal
       if (!simpleCheckModelLit(av))
@@ -307,7 +307,7 @@ bool NlModel::addSubstitution(TNode v, TNode s)
     Node ms = arithSubstitute(sub, tmp);
     if (ms != sub)
     {
-      sub = Rewriter::rewrite(ms);
+      sub = rewrite(ms);
     }
   }
   d_substitutions.add(v, s);
@@ -376,7 +376,7 @@ bool NlModel::solveEqualitySimple(Node eq,
   if (!d_substitutions.empty())
   {
     seq = arithSubstitute(eq, d_substitutions);
-    seq = Rewriter::rewrite(seq);
+    seq = rewrite(seq);
     if (seq.isConst())
     {
       if (seq.getConst<bool>())
@@ -580,7 +580,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
       {
         lit2 = lit2.negate();
       }
-      lit2 = Rewriter::rewrite(lit2);
+      lit2 = rewrite(lit2);
       bool success = simpleCheckModelLit(lit2);
       if (success != pol)
       {
@@ -669,7 +669,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
           b = it->second;
           t = nm->mkNode(PLUS, t, nm->mkNode(MULT, b, v));
         }
-        t = Rewriter::rewrite(t);
+        t = rewrite(t);
         Trace("nl-ext-cms-debug") << "Trying to find min/max for quadratic "
                                   << t << "..." << std::endl;
         Trace("nl-ext-cms-debug") << "    a = " << a << std::endl;
@@ -677,7 +677,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
         // find maximal/minimal value on the interval
         Node apex = nm->mkNode(
             DIVISION, nm->mkNode(UMINUS, b), nm->mkNode(MULT, d_two, a));
-        apex = Rewriter::rewrite(apex);
+        apex = rewrite(apex);
         Assert(apex.isConst());
         // for lower, upper, whether we are greater than the apex
         bool cmp[2];
@@ -686,7 +686,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
         {
           boundn[r] = r == 0 ? bit->second.first : bit->second.second;
           Node cmpn = nm->mkNode(GT, boundn[r], apex);
-          cmpn = Rewriter::rewrite(cmpn);
+          cmpn = rewrite(cmpn);
           Assert(cmpn.isConst());
           cmp[r] = cmpn.getConst<bool>();
         }
@@ -717,12 +717,12 @@ bool NlModel::simpleCheckModelLit(Node lit)
             {
               qsub.d_subs.back() = boundn[r];
               Node ts = arithSubstitute(t, qsub);
-              tcmpn[r] = Rewriter::rewrite(ts);
+              tcmpn[r] = rewrite(ts);
             }
             Node tcmp = nm->mkNode(LT, tcmpn[0], tcmpn[1]);
             Trace("nl-ext-cms-debug")
                 << "  ...both sides of apex, compare " << tcmp << std::endl;
-            tcmp = Rewriter::rewrite(tcmp);
+            tcmp = rewrite(tcmp);
             Assert(tcmp.isConst());
             unsigned bindex_use = (tcmp.getConst<bool>() == pol) ? 1 : 0;
             Trace("nl-ext-cms-debug")
@@ -756,7 +756,7 @@ bool NlModel::simpleCheckModelLit(Node lit)
   if (!qsub.empty())
   {
     Node slit = arithSubstitute(lit, qsub);
-    slit = Rewriter::rewrite(slit);
+    slit = rewrite(slit);
     return simpleCheckModelLit(slit);
   }
   return false;
@@ -1003,7 +1003,7 @@ bool NlModel::simpleCheckModelMsum(const std::map<Node, Node>& msum, bool pol)
     comp = comp.negate();
   }
   Trace("nl-ext-cms") << "  comparison is : " << comp << std::endl;
-  comp = Rewriter::rewrite(comp);
+  comp = rewrite(comp);
   Assert(comp.isConst());
   Trace("nl-ext-cms") << "  returned : " << comp << std::endl;
   return comp == d_true;
@@ -1073,7 +1073,7 @@ void NlModel::getModelValueRepair(
         witness = nm->mkNode(MULT,
                              nm->mkConst(CONST_RATIONAL, Rational(1, 2)),
                              nm->mkNode(PLUS, l, u));
-        witness = Rewriter::rewrite(witness);
+        witness = rewrite(witness);
         Trace("nl-model") << v << " witness is " << witness << std::endl;
       }
       approximations[v] = std::pair<Node, Node>(pred, witness);
index 7dcd89a4ac39a8abf41d3f14b9b4b5d46c4dc60e..e195aa9b208202e1e23eb480a023b0189fd7eec6 100644 (file)
@@ -23,6 +23,7 @@
 #include "expr/kind.h"
 #include "expr/node.h"
 #include "expr/subs.h"
+#include "smt/env_obj.h"
 
 namespace cvc5 {
 
@@ -48,12 +49,12 @@ class NonlinearExtension;
  * model in the case it can determine that a model exists. These include
  * techniques based on solving (quadratic) equations and bound analysis.
  */
-class NlModel
+class NlModel : protected EnvObj
 {
   friend class NonlinearExtension;
 
  public:
-  NlModel();
+  NlModel(Env& env);
   ~NlModel();
   /**
    * This method is called once at the beginning of a last call effort check,
index e757410966ce913c66c6cf6fb31882149ed7ea30..77bb164a965b13aeeb2a9b1233046944ef3d051c 100644 (file)
@@ -48,7 +48,7 @@ NonlinearExtension::NonlinearExtension(Env& env,
       d_checkCounter(0),
       d_extTheoryCb(state.getEqualityEngine()),
       d_extTheory(env, d_extTheoryCb, d_im),
-      d_model(),
+      d_model(env),
       d_trSlv(d_env, d_im, d_model),
       d_extState(d_im, d_model, d_env),
       d_factoringSlv(d_env, &d_extState),
@@ -122,7 +122,7 @@ void NonlinearExtension::getAssertions(std::vector<Node>& assertions)
   }
   Valuation v = d_containing.getValuation();
 
-  BoundInference bounds;
+  BoundInference bounds(d_env);
 
   std::unordered_set<Node> init_assertions;
 
index 643ba9a28bf103200d067d8f3c9aef5d719ff16c..bf8798485078c895aedcc408292c3114b1e8876a 100644 (file)
@@ -4782,8 +4782,11 @@ std::pair<bool, Node> TheoryArithPrivate::entailmentCheck(TNode lit, const Arith
   return make_pair(false, Node::null());
 }
 
-bool TheoryArithPrivate::decomposeTerm(Node term, Rational& m, Node& p, Rational& c){
-  Node t = Rewriter::rewrite(term);
+bool TheoryArithPrivate::decomposeTerm(Node t,
+                                       Rational& m,
+                                       Node& p,
+                                       Rational& c)
+{
   if(!Polynomial::isMember(t)){
     return false;
   }
@@ -4879,12 +4882,13 @@ bool TheoryArithPrivate::decomposeLiteral(Node lit, Kind& k, int& dir, Rational&
   // left : lm*( lp ) + lc
   // right: rm*( rp ) + rc
   Rational lc, rc;
-  bool success = decomposeTerm(left, lm, lp, lc);
+  bool success = decomposeTerm(rewrite(left), lm, lp, lc);
   if(!success){ return false; }
-  success = decomposeTerm(right, rm, rp, rc);
+  success = decomposeTerm(rewrite(right), rm, rp, rc);
   if(!success){ return false; }
 
-  Node diff = Rewriter::rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right));
+  Node diff =
+      rewrite(NodeManager::currentNM()->mkNode(kind::MINUS, left, right));
   Rational dc;
   success = decomposeTerm(diff, dm, dp, dc);
   Assert(success);
index 918f73a534d43a7c21d7464271c462c49abbc94a..7c90352d2b315a65f19ed5ab723cb2523670964d 100644 (file)
@@ -162,20 +162,26 @@ private:
   //std::pair<DeltaRational, Node> inferBound(TNode term, bool lb, int maxRounds = -1, const DeltaRational* threshold = NULL);
 
 private:
-  static bool decomposeTerm(Node term, Rational& m, Node& p, Rational& c);
-  static bool decomposeLiteral(Node lit, Kind& k, int& dir, Rational& lm,  Node& lp, Rational& rm, Node& rp, Rational& dm, Node& dp, DeltaRational& sep);
-  static void setToMin(int sgn, std::pair<Node, DeltaRational>& min, const std::pair<Node, DeltaRational>& e);
-
-  /**
-   * The map between arith variables to nodes.
-   */
-  //ArithVarNodeMap d_arithvarNodeMap;
-
-  typedef ArithVariables::var_iterator var_iterator;
-  var_iterator var_begin() const { return d_partialModel.var_begin(); }
-  var_iterator var_end() const { return d_partialModel.var_end(); }
-
-  NodeSet d_setupNodes;
+ static bool decomposeTerm(Node t, Rational& m, Node& p, Rational& c);
+ bool decomposeLiteral(Node lit,
+                       Kind& k,
+                       int& dir,
+                       Rational& lm,
+                       Node& lp,
+                       Rational& rm,
+                       Node& rp,
+                       Rational& dm,
+                       Node& dp,
+                       DeltaRational& sep);
+ static void setToMin(int sgn,
+                      std::pair<Node, DeltaRational>& min,
+                      const std::pair<Node, DeltaRational>& e);
+
+ typedef ArithVariables::var_iterator var_iterator;
+ var_iterator var_begin() const { return d_partialModel.var_begin(); }
+ var_iterator var_end() const { return d_partialModel.var_end(); }
+
+ NodeSet d_setupNodes;
 public:
   bool isSetup(Node n) const {
     return d_setupNodes.find(n) != d_setupNodes.end();